keras.js 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211
  /* jshint esversion: 6 */
  // Module namespaces: keep whatever `keras` / `json` already hold when truthy,
  // otherwise start an empty `keras` namespace and load the JSON reader module.
  var keras = keras ? keras : {};
  var json = json ? json : require('./json');
  /*
   * keras.ModelFactory: recognizes Keras / TensorFlow.js model and weights files
   * and opens them into a keras.Model.
   */
  keras.ModelFactory = class {

    /**
     * Returns true when `context` looks like something this factory can open:
     * a stream starting with the 8-byte HDF5 signature, or JSON whose tags match
     * a Keras model config, a TensorFlow.js layers model, or a TensorFlow.js
     * weights manifest. MXNet symbol JSON and TensorFlow.js graph models are rejected.
     * @param context  provides `stream` (length / peek) and `tags('json')`
     * @returns {boolean}
     */
    match(context) {
      const stream = context.stream;
      const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
      if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
        return true;
      }
      const tags = context.tags('json');
      if (tags.has('mxnet_version')) {
        return false;
      }
      if (tags.has('nodes') && tags.has('arg_nodes') && tags.has('heads')) {
        return false;
      }
      if (tags.has('modelTopology') && tags.get('format') !== 'graph-model') {
        return true;
      }
      if (tags.has('model_config') || (tags.has('class_name') && tags.has('config'))) {
        return true;
      }
      if (tags.has('[].weights') && tags.has('[].paths')) {
        return true;
      }
      return false;
    }

    /**
     * Opens the model in `context`. The file extension selects the parser:
     * HDF5 for .keras/.h5/.hd5/.hdf5/.model/.pb/.pth, JSON for .json.
     * @param context  provides `identifier` (file name) and `stream`
     * @param host     provides `require(id)` returning a Promise of a module
     * @returns {Promise<keras.Model>}
     * @throws keras.Error (as a rejection) when the content is not a recognizable
     *         Keras model or weights file
     */
    open(context, host) {
      return host.require('./hdf5').then((hdf5) => {
        let format = 'Keras';
        let producer = '';
        let backend = '';
        let model_config = null;
        let rootGroup = null;
        const identifier = context.identifier;
        const weights = new keras.Weights();
        switch (identifier.split('.').pop().toLowerCase()) {
          case 'keras':
          case 'h5':
          case 'hd5':
          case 'hdf5':
          case 'model':
          case 'pb':
          case 'pth': {
            const buffer = context.stream.peek();
            const file = new hdf5.File(buffer);
            rootGroup = file.rootGroup;
            if (rootGroup.attribute('model_config') || rootGroup.attribute('layer_names')) {
              // Keras layout: a 'model_config' JSON attribute and/or per-layer weight groups.
              const model_config_json = rootGroup.attribute('model_config');
              if (model_config_json) {
                const reader = json.TextReader.create(model_config_json);
                model_config = reader.read();
              }
              backend = rootGroup.attribute('backend') || '';
              const version = rootGroup.attribute('keras_version') || '';
              format = format + (version ? ' v' + version : '');
              let model_weights_group = rootGroup.group('model_weights');
              if (!model_weights_group && rootGroup.attribute('layer_names')) {
                // Weights-only layout: 'layer_names' lives directly on the root group.
                model_weights_group = rootGroup;
              }
              if (model_weights_group) {
                model_weights_group = new keras.Group(model_weights_group);
                // A weights group without 'layer_names' contributes no weights instead of throwing.
                for (const layer_name of model_weights_group.attribute('layer_names') || []) {
                  const layer_weights = model_weights_group.group(layer_name);
                  if (layer_weights) {
                    const weight_names = layer_weights.attribute('weight_names');
                    if (weight_names && weight_names.length > 0) {
                      for (const weight_name of weight_names) {
                        const weight = layer_weights.group(weight_name);
                        if (weight && weight.value) {
                          const variable = weight.value;
                          const tensor = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
                          if (model_config) {
                            weights.add(layer_name, tensor);
                          }
                          else {
                            // No model config: key the tensor by its weight path minus the last
                            // component, prefixed with layer_name unless it already starts with it.
                            const components = weight_name.split('/');
                            components.pop();
                            const name = (components.length === 0 || components[0] !== layer_name) ? [ layer_name ].concat(components).join('/') : components.join('/');
                            weights.add(name, tensor);
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
            else {
              // Generic HDF5 weights file: only an optional 'nb_layers' root attribute is tolerated.
              const attributes = new Set([ 'nb_layers' ]);
              if (Object.keys(rootGroup.attributes).filter((name) => !attributes.has(name)).length !== 0 || rootGroup.value !== null) {
                throw new keras.Error('File format is not HDF5 Weights');
              }
              format = 'HDF5 Weights';
              // Descend through a single attribute-less, value-less wrapper group.
              if (Object.keys(rootGroup.attributes).length === 0 && rootGroup.value === null &&
                  rootGroup.groups.length == 1 && rootGroup.groups[0] &&
                  Object.keys(rootGroup.groups[0].attributes).length === 0 && rootGroup.groups[0].value === null) {
                rootGroup = rootGroup.groups[0];
              }
              if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.groups.length == 0 && group.value !== null)) {
                // Flat layout: every child is a dataset.
                for (const group of rootGroup.groups) {
                  const variable = group.value;
                  const tensor = new keras.Tensor(group.name, variable.type, variable.shape, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                  weights.add('', tensor);
                }
              }
              else if (rootGroup.groups.every((group) => Object.keys(group.attributes).length === 0 && group.value === null)) {
                // Two-level layout: module groups containing dataset groups.
                for (const group of rootGroup.groups) {
                  const moduleName = group.attributes.name || group.name;
                  for (const variableGroup of group.groups) {
                    if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
                      throw new keras.Error('Group is not HDF5 tensor variable.');
                    }
                    const variable = variableGroup.value;
                    if (!variable) {
                      throw new keras.Error('Variable value is not HDF5 tensor.');
                    }
                    // Without a module name the tensor is named after the dataset alone
                    // (reading `moduleName.name` on a string yielded undefined).
                    const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : variableGroup.name;
                    const tensor = new keras.Tensor(name, variable.type, variable.shape, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                    weights.add(moduleName, tensor);
                  }
                }
              }
              else if (rootGroup.groups.every((group) => group.value === null && group.groups.every((variable) => Object.keys(variable.attributes).length === 0 && variable.value !== null))) {
                // Two-level layout where module groups may carry attributes.
                for (const group of rootGroup.groups) {
                  const moduleName = group.attributes.name || group.name;
                  for (const variableGroup of group.groups) {
                    if (Object.keys(variableGroup.attributes).length !== 0 || variableGroup.groups.length !== 0) {
                      throw new keras.Error('Variable format is not HDF5 Weights');
                    }
                    const variable = variableGroup.value;
                    if (!variable) {
                      throw new keras.Error('Variable value is not HDF5 Weights');
                    }
                    const name = moduleName ? [ moduleName, variableGroup.name ].join('/') : variableGroup.name;
                    const tensor = new keras.Tensor(name, variable.type, variable.shape, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                    weights.add(moduleName, tensor);
                  }
                }
              }
              else {
                // Arbitrary nesting: recurse through plain groups, collect leaf datasets
                // keyed by their parent path.
                const walk = function(group) {
                  if (Object.keys(group.attributes).length === 0 && group.value === null && group.groups.length > 0) {
                    for (const subGroup of group.groups) {
                      walk(subGroup);
                    }
                  }
                  else if (Object.keys(group.attributes).length === 0 && group.value !== null && group.groups.length === 0) {
                    const variable = group.value;
                    const variableName = group.path;
                    let moduleName = variableName;
                    const parts = variableName.split('/');
                    if (parts.length > 1) {
                      parts.pop();
                      moduleName = parts.join('/');
                    }
                    const tensor = new keras.Tensor(variableName, variable.type, variable.shape, variable.littleEndian, variable.type === 'string' ? variable.value : variable.data);
                    weights.add(moduleName, tensor);
                  }
                  else {
                    throw new keras.Error('Module group format is not HDF5 Weights');
                  }
                };
                walk(rootGroup);
              }
            }
            break;
          }
          case 'json': {
            const buffer = context.stream.peek();
            const reader = json.TextReader.create(buffer);
            const root = reader.read();
            if (root && Array.isArray(root) && root.every((manifest) => manifest && Array.isArray(manifest.weights) && Array.isArray(manifest.paths))) {
              // TensorFlow.js weights manifest: tensors reference external shard files.
              format = 'TensorFlow.js Weights';
              rootGroup = {};
              for (const manifest of root) {
                const paths = manifest.paths.join(';');
                for (const weight of manifest.weights) {
                  const tensor = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, paths);
                  const parts = weight.name.split('/');
                  parts.pop();
                  const layer = parts.join('/');
                  weights.add(layer, tensor);
                }
              }
            }
            else {
              // `root` may be null or a primitive; such inputs fall through to the
              // "'model_config' / 'class_name' is not present" errors below.
              if (root && root.keras_version) {
                const version = root.keras_version;
                format = format + (version ? (' v' + version) : '');
              }
              if (root && root.backend) {
                backend = root.backend;
              }
              model_config = root;
              if (model_config && model_config.modelTopology) {
                // TensorFlow.js layers model: keep a previously found backend if the topology has none.
                backend = model_config.modelTopology.backend || backend;
                const version = model_config.modelTopology.keras_version;
                format = format + (version ? (' v' + version) : '');
                format = 'TensorFlow.js ' + (model_config.format ? model_config.format : format);
                producer = model_config.convertedBy || model_config.generatedBy || '';
                for (const manifest of model_config.weightsManifest || []) {
                  const paths = Array.isArray(manifest.paths) ? manifest.paths.join(';') : '';
                  for (const weight of manifest.weights || []) {
                    const tensor = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, paths);
                    weights.add('', tensor);
                  }
                }
                model_config = model_config.modelTopology;
              }
              if (model_config && model_config.model_config) {
                model_config = model_config.model_config;
              }
            }
            break;
          }
        }
        if (!rootGroup && !model_config) {
          throw new keras.Error('\'model_config\' is not present.');
        }
        if (!rootGroup && !model_config.class_name) {
          throw new keras.Error('\'class_name\' is not present.');
        }
        return keras.Metadata.open(host).then((metadata) => {
          return new keras.Model(metadata, format, producer, backend, model_config, weights);
        });
      });
    }
  };
  /*
   * keras.Model: top-level container exposing format / producer / runtime
   * strings and a single keras.Graph.
   */
  keras.Model = class {

    /**
     * @param metadata  forwarded to keras.Graph
     * @param format    human-readable format string
     * @param producer  producing tool name, or ''
     * @param backend   reported through the `runtime` getter
     * @param config    model config forwarded to keras.Graph (may be null)
     * @param weights   keras.Weights forwarded to keras.Graph
     */
    constructor(metadata, format, producer, backend, config, weights) {
      Object.assign(this, {
        _format: format,
        _backend: backend,
        _producer: producer
      });
      const graph = new keras.Graph(metadata, config, weights);
      this._graphs = [ graph ];
    }

    get name() { return null; }

    get description() { return null; }

    get format() { return this._format; }

    get producer() { return this._producer; }

    get runtime() { return this._backend; }

    get graphs() { return this._graphs; }
  };
  /*
   * keras.Graph: converts a Keras Sequential / Functional config (or, without a
   * config, a bare weights collection) into inputs, outputs and keras.Node objects.
   */
  keras.Graph = class {

    /**
     * With a config, dispatches on `config.class_name` ('AllCNN'/'Sequential' or
     * 'Functional'/'Model'). Without one, creates a 'Weights' node for every key of
     * `weights` whose tensor list has at most 6 entries.
     * @param metadata  metadata provider forwarded to every keras.Node
     * @param config    top-level { class_name, config } object, or null
     * @param weights   keras.Weights collection
     * @throws keras.Error when `config.class_name` is not a supported container type
     */
    constructor(metadata, config, weights) {
      this._metadata = metadata;
      this._inputs = [];
      this._outputs = [];
      this._nodes = [];
      this._groups = false;
      if (config) {
        this._name = config.name || (config.config && config.config.name ? config.config.name : '');
        switch (config.class_name) {
          case 'AllCNN':
          case 'Sequential':
            this._loadSequential(config.config, weights, '', null, null);
            break;
          case 'Functional':
          case 'Model':
            this._loadModel(config.config, weights, '', null, null);
            break;
          default:
            throw new keras.Error('\'' + config.class_name + '\' is not supported.');
        }
      }
      else if (weights) {
        for (const layer of weights.keys()) {
          // NOTE(review): keys with more than 6 tensors get no node; the reason for this cut-off is not visible here.
          if (weights.get('', layer).length <= 6) {
            const node = new keras.Node(metadata, 'Weights', { name: layer }, [], [], '', weights);
            this._nodes.push(node);
          }
        }
      }
    }

    get name() {
      return this._name;
    }

    get groups() {
      return this._groups ? true : false;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get nodes() {
      return this._nodes;
    }

    /**
     * Loads a functional model config. Layers are linked through their
     * `inbound_nodes`; each entry of `config.input_layers` becomes a graph input
     * (or is replaced by the matching outer argument when `inputs` is given), and
     * each entry of `config.output_layers` becomes a graph output (or is mapped onto
     * `outputs` when given). InputLayer entries supply the input type and get no node.
     * @param config   { layers, input_layers, output_layers }
     * @param weights  keras.Weights collection
     * @param group    name prefix for nested models ('' at top level)
     * @param inputs   outer argument names for a nested model, or null
     * @param outputs  outer argument names for a nested model, or null
     */
    _loadModel(config, weights, group, inputs, outputs) {
      if (group) {
        this._groups = true;
      }
      // A connection is an array of 3 or 4 items whose first item is the source
      // layer name; item [2] is read as the source output index.
      const isConnection = (item) => Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string';
      const nodeMap = new Map();
      if (config.layers) {
        for (const layer of config.layers) {
          // Every layer gets its own connection lists, so a layer with a missing or
          // duplicate name cannot hit an undefined `_inputs` below.
          layer._inputs = [];
          layer._outputs = [];
          if (layer.name && !nodeMap.has(layer.name)) {
            nodeMap.set(layer.name, layer);
          }
        }
        for (const layer of config.layers) {
          if (layer.inbound_nodes) {
            for (let inbound_node of layer.inbound_nodes) {
              // A single connection is wrapped into a list of connections.
              if (isConnection(inbound_node)) {
                inbound_node = [ inbound_node ];
              }
              // A list of lists of connections is flattened by one level.
              if (Array.isArray(inbound_node) && inbound_node.every((array) => Array.isArray(array) && array.every((item) => isConnection(item)))) {
                inbound_node = inbound_node.flat();
              }
              for (const inbound_connection of inbound_node) {
                let inputName = inbound_connection[0];
                const inputNode = nodeMap.get(inputName);
                if (inputNode) {
                  // Non-zero output indices are encoded as 'name:index'.
                  const inputIndex = inbound_connection[2];
                  if (inputIndex != 0) {
                    inputName += ':' + inputIndex.toString();
                  }
                  while (inputIndex >= inputNode._outputs.length) {
                    inputNode._outputs.push('');
                  }
                  inputNode._outputs[inputIndex] = inputName;
                }
                layer._inputs.push(inputName);
              }
            }
          }
        }
      }
      const input_layers = config.input_layers;
      if (input_layers) {
        for (let i = 0; i < input_layers.length; i++) {
          const input_layer = input_layers[i];
          const name = input_layer[0];
          let type = null;
          const node = nodeMap.get(name);
          if (node && node.class_name === 'InputLayer') {
            type = this._getInputType(node);
            nodeMap.delete(name);
          }
          if (inputs && i < inputs.length) {
            // Nested model: rewire consumers of this input to the outer argument.
            if (config.layers) {
              for (const layer of config.layers) {
                if (layer._inputs) {
                  layer._inputs = layer._inputs.map((input) => {
                    return input === name ? inputs[i] : input;
                  });
                }
              }
            }
          }
          else {
            this._inputs.push(new keras.Parameter(name, true, [ new keras.Argument(name, type, null) ]));
          }
        }
      }
      // Maps inner output names to outer argument names for nested models.
      const inputMap = new Map();
      const output_layers = config.output_layers;
      if (output_layers) {
        for (let j = 0; j < output_layers.length; j++) {
          const output_layer = output_layers[j];
          let outputName = output_layer[0];
          const outputNode = nodeMap.get(outputName);
          let addGraphOutput = true;
          if (outputs && j < outputs.length) {
            inputMap.set(outputName, outputs[j]);
            outputName = outputs[j];
            addGraphOutput = false;
          }
          if (outputNode) {
            const outputIndex = output_layer[2];
            if (outputIndex != 0) {
              outputName += ':' + outputIndex.toString();
            }
            while (outputIndex >= outputNode._outputs.length) {
              outputNode._outputs.push('');
            }
            outputNode._outputs[outputIndex] = outputName;
          }
          if (addGraphOutput) {
            this._outputs.push(new keras.Parameter(outputName, true, [ new keras.Argument(outputName, null, null) ]));
          }
        }
      }
      if (config.layers) {
        for (const layer of config.layers) {
          if (nodeMap.has(layer.name)) {
            this._loadNode(layer, layer._inputs, layer._outputs, weights, group, inputMap);
          }
        }
      }
    }

    /**
     * Loads a Sequential config: layers are chained in order, the first consumes
     * 'input' (or `inputs[0]` when nested) and the last feeds the graph output
     * (or `outputs[0]` when nested). Accepts either { layers: [...] } or a bare array.
     */
    _loadSequential(config, weights, group, inputs, outputs) {
      if (group) {
        this._groups = true;
      }
      const inputName = 'input';
      let inputType = null;
      let argument = inputName;
      let index = 0;
      const layers = config.layers ? config.layers : config;
      for (const layer of layers) {
        let name = index.toString();
        let nodeInputs = [ argument ];
        if (index === 0) {
          if (inputs && inputs.length > 0) {
            nodeInputs = [ inputs[0] ];
          }
          else {
            inputType = this._getInputType(layer);
          }
        }
        index++;
        if (layer.config && layer.config.name) {
          name = layer.config.name;
        }
        argument = name;
        let nodeOutputs = [ argument ];
        if (index === layers.length) {
          if (outputs && outputs.length > 0) {
            nodeOutputs = [ outputs[0] ];
            argument = null;
          }
        }
        this._loadNode(layer, nodeInputs, nodeOutputs, weights, group);
      }
      if (!inputs) {
        this._inputs.push(new keras.Parameter(inputName, true, [ new keras.Argument(inputName, inputType, null) ]));
      }
      if (argument) {
        this._outputs.push(new keras.Parameter(argument, true, [ new keras.Argument(argument, null, null) ]));
      }
    }

    /**
     * Adds one layer: nested Sequential / Functional / Model layers recurse with a
     * group prefix, every other layer becomes a keras.Node after its input names are
     * translated through `inputMap`.
     */
    _loadNode(layer, inputs, outputs, weights, group, inputMap) {
      const class_name = layer.class_name;
      switch (class_name) {
        case 'Sequential': {
          const name = layer.name || (layer.config ? layer.config.name : '');
          this._loadSequential(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
          break;
        }
        case 'Functional':
        case 'Model': {
          const name = layer.name || (layer.config ? layer.config.name : '');
          this._loadModel(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
          break;
        }
        default: {
          inputs = inputs.map((input) => inputMap && inputMap.has(input) ? inputMap.get(input) : input);
          const node = new keras.Node(this._metadata, class_name, layer.config, inputs, outputs, group, weights);
          this._nodes.push(node);
          break;
        }
      }
    }

    /**
     * Builds a keras.TensorType from an input layer's config. Reads `dtype` and the
     * batch shape (`batch_input_shape`, or `batch_shape` when that is absent); null
     * dimensions become '?'. Consumed keys are deleted from the config so keras.Node
     * does not list them again as attributes.
     * @returns keras.TensorType, or null when the layer has no config
     */
    _getInputType(layer) {
      if (layer && layer.config) {
        let dataType = '?';
        let shape = [];
        const config = layer.config;
        if (config.dtype) {
          dataType = config.dtype;
          delete config.dtype;
        }
        if (config.batch_input_shape) {
          shape = config.batch_input_shape.map(s => s == null ? '?' : s);
          delete config.batch_input_shape;
        }
        else if (Array.isArray(config.batch_shape)) {
          shape = config.batch_shape.map(s => s == null ? '?' : s);
          delete config.batch_shape;
        }
        return new keras.TensorType(dataType, new keras.TensorShape(shape));
      }
      return null;
    }
  };
  /*
   * keras.Parameter: a named slot (graph/node input or output) holding a list of
   * keras.Argument objects, with a visibility flag.
   */
  keras.Parameter = class {

    /**
     * @param name     slot name
     * @param visible  visibility flag, returned unchanged by `visible`
     * @param args     array of keras.Argument
     */
    constructor(name, visible, args) {
      Object.assign(this, {
        _name: name,
        _visible: visible,
        _arguments: args
      });
    }

    get name() { return this._name; }

    get visible() { return this._visible; }

    get arguments() { return this._arguments; }
  };
  /*
   * keras.Argument: a named value flowing between nodes, optionally backed by an
   * initializer tensor whose type then takes precedence over `type`.
   */
  keras.Argument = class {

    /**
     * @param name         argument identifier; must be a string
     * @param type         keras.TensorType or null
     * @param initializer  keras.Tensor or null
     * @throws keras.Error when `name` is not a string
     */
    constructor(name, type, initializer) {
      const isString = typeof name === 'string';
      if (!isString) {
        throw new keras.Error(`Invalid argument identifier '${JSON.stringify(name)}'.`);
      }
      this._name = name;
      this._type = type ? type : null;
      this._initializer = initializer ? initializer : null;
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    get initializer() {
      return this._initializer;
    }
  };
  /*
   * keras.Node: one Keras layer. Config keys become attributes, matching weights
   * become initializer inputs, and inputs are named from the operator metadata.
   */
  keras.Node = class {

    /**
     * @param metadata  metadata provider (`type(name)`, `attribute(type, name)`)
     * @param type      Keras class name of the layer
     * @param config    layer config; its keys (except 'name') become attributes; may be null
     * @param inputs    input argument names; copied, so the caller's array is not modified
     * @param outputs   output argument names
     * @param group     name prefix of the enclosing nested model, or ''/null
     * @param weights   keras.Weights used to look up initializer tensors, or null
     */
    constructor(metadata, type, config, inputs, outputs, group, weights) {
      this._group = group || '';
      this._metadata = metadata;
      this._type = type;
      const name = config && config.name ? config.name : '';
      this._name = (this._group ? this._group + '/' : '') + name;
      this._inputs = [];
      this._outputs = [];
      this._attributes = [];
      // Initializer names are appended and the list is consumed below; work on a copy.
      const pending = inputs.slice();
      let weightsGroup = group;
      let names = [ name ];
      if ((type == 'Bidirectional' || type == 'TimeDistributed') && (config && config.layer)) {
        const inner = config.layer;
        // Removed so the wrapped layer is not also listed as an attribute.
        delete config.layer;
        this._inner = new keras.Node(this._metadata, inner.class_name, inner.config, [], [], null, null);
        if (type == 'Bidirectional' && inner.config && inner.config.name) {
          // Bidirectional weights are looked up under forward_/backward_ prefixed names.
          names = [ name + '/forward_' + inner.config.name, name + '/backward_' + inner.config.name ];
          if (!weightsGroup) {
            weightsGroup = name;
          }
        }
      }
      const initializers = {};
      if (weights) {
        for (const weightName of names) {
          for (const initializer of weights.get(weightsGroup, weightName)) {
            pending.push(initializer.name);
            initializers[initializer.name] = initializer;
          }
        }
      }
      if (config) {
        for (const key of Object.keys(config)) {
          const value = config[key];
          if (key != 'name' && value != null) {
            this._attributes.push(new keras.Attribute(metadata.attribute(this.type, key), key, value));
          }
        }
      }
      const schema = this._metadata.type(this.type);
      const innerType = this.inner ? this.inner.type : null;
      const innerSchema = innerType ? this._metadata.type(innerType) : null;
      // Suffixes that span two '_'-separated parts of a weight name.
      const inputNames = new Set([ 'recurrent_kernel', 'running_mean', 'running_std', 'moving_mean', 'moving_variance', 'depthwise_filter', 'pointwise_filter' ]);
      let inputIndex = 0;
      while (pending.length > 0) {
        let variadic = false;
        let inputName = null;
        let visible = true;
        if (!innerSchema || inputIndex == 0) {
          if (schema && schema.inputs && inputIndex < schema.inputs.length) {
            const input = schema.inputs[inputIndex];
            inputName = input.name;
            // With scale disabled there is no gamma tensor; skip its schema slot.
            if (type === 'BatchNormalization' && inputName === 'gamma' && config && config.scale === false) {
              inputIndex++;
              continue;
            }
            visible = input.visible == false ? false : true;
            if (schema.inputs[inputIndex].option == 'variadic') {
              variadic = true;
            }
          }
        }
        else {
          switch (type) {
            case 'Bidirectional': {
              // Inner schema input 0 is the data input, so weight slots map to
              // inner inputs 1..n-1 for the forward layer, then again for the backward one.
              let innerIndex = inputIndex;
              if (innerSchema && innerSchema.inputs) {
                if (innerIndex < innerSchema.inputs.length) {
                  inputName = 'forward_' + innerSchema.inputs[innerIndex].name;
                }
                else {
                  innerIndex = innerIndex - innerSchema.inputs.length + 1;
                  if (innerIndex < innerSchema.inputs.length) {
                    inputName = 'backward_' + innerSchema.inputs[innerIndex].name;
                  }
                }
              }
              visible = false;
              break;
            }
            case 'TimeDistributed':
              if (innerSchema && innerSchema.inputs && inputIndex < innerSchema.inputs.length) {
                inputName = innerSchema.inputs[inputIndex].name;
              }
              break;
          }
        }
        const input = !variadic ? [ pending.shift() ] : pending.splice(0, pending.length);
        const inputArguments = input.map((id) => {
          return new keras.Argument(id, null, initializers[id]);
        });
        if (!inputName && inputArguments.length == 1 && inputArguments[0].initializer && inputArguments[0].initializer.name) {
          if (names.length === 1 && names[0] === '') {
            inputName = inputArguments[0].initializer.name;
          }
          else {
            // Derive the slot name from the weight name, e.g. '.../moving_mean:0' -> 'moving_mean'.
            const parts = inputArguments[0].initializer.name.split('/').pop().split(':').shift().split('_');
            const inputName1 = parts.pop();
            const inputName2 = parts.length > 0 ? [ parts.pop(), inputName1 ].join('_') : '';
            inputName = inputNames.has(inputName2) ? inputName2 : inputName1;
          }
        }
        this._inputs.push(new keras.Parameter(inputName || inputIndex.toString(), visible, inputArguments));
        inputIndex++;
      }
      this._outputs = outputs.map((output, outputIndex) => {
        const outputName =
          (schema && schema.outputs && outputIndex < schema.outputs.length && schema.outputs[outputIndex] && schema.outputs[outputIndex].name) ?
            schema.outputs[outputIndex].name :
            outputIndex.toString();
        return new keras.Parameter(outputName, true, [ new keras.Argument(output, null, null) ]);
      });
    }

    get type() {
      return this._type;
    }

    get metadata() {
      return this._metadata.type(this._type);
    }

    get name() {
      return this._name;
    }

    get group() {
      return this._group;
    }

    get inputs() {
      return this._inputs;
    }

    get outputs() {
      return this._outputs;
    }

    get attributes() {
      return this._attributes;
    }

    get inner() {
      return this._inner;
    }
  };
  /*
   * keras.Attribute: one layer-config entry shown on a node, with type and
   * visibility derived from the attribute schema.
   */
  keras.Attribute = class {

    /**
     * Values shaped like { class_name, config } are flattened by _convert().
     * 'trainable' (typed boolean) and 'dtype' are always hidden; other attributes
     * are hidden when the schema sets visible: false or the value equals the
     * schema default.
     * @param schema  attribute metadata ({ type, visible, default }) or null/undefined
     * @param name    config key
     * @param value   config value
     */
    constructor(schema, name, value) {
      this._name = name;
      this._value = value;
      // typeof null === 'object', so null is excluded before reading properties.
      if (value !== null && typeof value === 'object' && value.class_name && value.config) {
        this._value = keras.Attribute._convert(value);
      }
      switch (name) {
        case 'trainable':
          this._type = 'boolean';
          this._visible = false;
          break;
        case 'dtype':
          this._visible = false;
          break;
        default: {
          if (schema) {
            if (schema.type) {
              this._type = schema.type;
            }
            if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
              this._visible = false;
            }
            else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
              if (keras.Attribute._isEquivalent(schema.default, value)) {
                this._visible = false;
              }
            }
          }
          break;
        }
      }
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get value() {
      return this._value;
    }

    get visible() {
      return this._visible == false ? false : true;
    }

    /**
     * Recursively flattens { class_name, config } objects into a plain object
     * holding the config keys plus `__type__` (the class name). Primitives and
     * arrays are returned unchanged; objects without a `config` have their own
     * keys converted (previously this dereferenced the missing config and threw).
     */
    static _convert(value) {
      if (Array.isArray(value) || value !== Object(value)) {
        return value;
      }
      const obj = {};
      if (value.config === undefined || value.config === null) {
        for (const key of Object.keys(value)) {
          obj[key] = keras.Attribute._convert(value[key]);
        }
        return obj;
      }
      if (value.class_name) {
        obj.__type__ = value.class_name;
      }
      for (const key of Object.keys(value.config)) {
        obj[key] = keras.Attribute._convert(value.config[key]);
      }
      return obj;
    }

    /**
     * Deep equality used to compare a value with its schema default. Treats NaN
     * as equal to NaN and distinguishes 0 from -0. Compares String/Number/Boolean/
     * Date/RegExp by value and arrays element-wise. Other objects are compared by
     * own enumerable keys.
     */
    static _isEquivalent(a, b) {
      if (a === b) {
        return a !== 0 || 1 / a === 1 / b;
      }
      if (a == null || b == null) {
        return false;
      }
      if (a !== a) {
        return b !== b;
      }
      const type = typeof a;
      if (type !== 'function' && type !== 'object' && typeof b != 'object') {
        return false;
      }
      // Explicit Object.prototype.toString instead of the implicit global `toString`.
      const className = Object.prototype.toString.call(a);
      if (className !== Object.prototype.toString.call(b)) {
        return false;
      }
      switch (className) {
        case '[object RegExp]':
        case '[object String]':
          return '' + a === '' + b;
        case '[object Number]':
          if (+a !== +a) {
            return +b !== +b;
          }
          return +a === 0 ? 1 / +a === 1 / b : +a === +b;
        case '[object Date]':
        case '[object Boolean]':
          return +a === +b;
        case '[object Array]': {
          let length = a.length;
          if (length !== b.length) {
            return false;
          }
          while (length--) {
            if (!keras.Attribute._isEquivalent(a[length], b[length])) {
              return false;
            }
          }
          return true;
        }
      }
      const keys = Object.keys(a);
      let size = keys.length;
      if (Object.keys(b).length != size) {
        return false;
      }
      while (size--) {
        const key = keys[size];
        if (!(Object.prototype.hasOwnProperty.call(b, key) && keras.Attribute._isEquivalent(a[key], b[key]))) {
          return false;
        }
      }
      return true;
    }
  };
  /**
   * A weight tensor loaded from a Keras model file.
   *
   * The element payload is decoded on demand: `value` decodes every element,
   * while `toString()` decodes a bounded prefix for display.
   */
  keras.Tensor = class {

      /**
       * @param {string} name - tensor name.
       * @param {string} type - element data type; decodable types are 'boolean',
       *     'float16', 'float32', 'float64', 'uint8', 'int32', 'int64' and 'string'.
       * @param {Array} shape - dimension list, wrapped in a keras.TensorShape.
       * @param {boolean} littleEndian - byte order for multi-byte element reads.
       * @param {*} data - raw payload: an object exposing buffer/byteOffset/byteLength
       *     for numeric and boolean types, or an indexable list for 'string'.
       * @param {*} reference - when set, the payload is not decoded (see `state`).
       */
      constructor(name, type, shape, littleEndian, data, reference) {
          this._name = name;
          const tensorShape = new keras.TensorShape(shape);
          this._type = new keras.TensorType(type, tensorShape);
          this._littleEndian = littleEndian;
          this._data = data;
          this._reference = reference;
      }

      get kind() {
          return 'Weights';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get reference() {
          return this._reference;
      }

      /** null when the tensor can be decoded, otherwise a message explaining why not. */
      get state() {
          return this._context().state;
      }

      /** All elements as nested arrays (a bare element for scalars), or null if undecodable. */
      get value() {
          const context = this._context();
          if (!context.state) {
              // No element cap: decode the whole payload.
              context.limit = Number.MAX_SAFE_INTEGER;
              return this._decode(context, 0);
          }
          return null;
      }

      /** Indented text rendering of the elements, or '' if undecodable. */
      toString() {
          const context = this._context();
          if (!context.state) {
              // Cap the number of decoded elements so large tensors stay printable;
              // truncation points are rendered as '...'.
              context.limit = 10000;
              return keras.Tensor._stringify(this._decode(context, 0), '', ' ');
          }
          return '';
      }

      /**
       * Builds the mutable decoding cursor shared by `_decode` calls.
       * `state` is non-null when decoding is impossible.
       */
      _context() {
          const context = { index: 0, count: 0, state: null };
          if (this._reference) {
              context.state = 'Tensor reference not implemented.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const dataType = this._type.dataType;
          const binaryTypes = [ 'boolean', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64' ];
          if (binaryTypes.includes(dataType)) {
              context.dataType = dataType;
              // Fixed-width types are read through a DataView over the payload bytes.
              context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
              context.littleEndian = this._littleEndian;
          }
          else if (dataType === 'string') {
              context.dataType = dataType;
              context.data = this._data;
          }
          else {
              context.state = 'Tensor data type is not supported.';
          }
          context.shape = this._type.shape.dimensions;
          return context;
      }

      /**
       * Recursively decodes the sub-tensor at `dimension`, advancing the cursor.
       * Once more than `context.limit` elements have been read, a null marker is
       * appended and the current level returns early.
       */
      _decode(context, dimension) {
          // A scalar (empty shape) is decoded as a one-element vector and unwrapped at the end.
          const shape = context.shape.length === 0 ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const innermost = dimension === shape.length - 1;
          const results = [];
          for (let position = 0; position < size; position++) {
              if (context.count > context.limit) {
                  results.push(null);
                  return results;
              }
              if (innermost) {
                  this._readElement(context, results);
                  context.count++;
              }
              else {
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return context.shape.length === 0 ? results[0] : results;
      }

      /**
       * Reads one element at `context.index`, appends it to `results` and advances
       * the index by the element's width. Unknown data types read nothing.
       * NOTE(review): getFloat16/getInt64 are not standard DataView methods in every
       * engine; presumably they are provided elsewhere in the project - confirm.
       */
      _readElement(context, results) {
          const type = context.dataType;
          const offset = context.index;
          const data = context.data;
          const littleEndian = context.littleEndian;
          let item;
          let width;
          if (type === 'float16') {
              item = data.getFloat16(offset, littleEndian);
              width = 2;
          }
          else if (type === 'float32') {
              item = data.getFloat32(offset, littleEndian);
              width = 4;
          }
          else if (type === 'float64') {
              item = data.getFloat64(offset, littleEndian);
              width = 8;
          }
          else if (type === 'boolean') {
              item = data.getInt8(offset) !== 0;
              width = 1;
          }
          else if (type === 'uint8') {
              item = data.getUint8(offset);
              width = 1;
          }
          else if (type === 'int32') {
              item = data.getInt32(offset, littleEndian);
              width = 4;
          }
          else if (type === 'int64') {
              item = data.getInt64(offset, littleEndian);
              width = 8;
          }
          else if (type === 'string') {
              item = data[offset];
              width = 1;
          }
          else {
              return;
          }
          results.push(item);
          context.index = offset + width;
      }

      /**
       * Renders decoded values as text: arrays become bracketed, comma-separated
       * lines nested by `indent`; null (truncation marker) becomes '...'.
       */
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const childIndentation = indentation + indent;
              const lines = [ indentation + '[' ];
              if (value.length > 0) {
                  const items = value.map((item) => keras.Tensor._stringify(item, childIndentation, indent));
                  lines.push(items.join(',\n'));
              }
              lines.push(indentation + ']');
              return lines.join('\n');
          }
          let text;
          if (value === null) {
              text = '...';
          }
          else if (typeof value === 'string') {
              text = '"' + value + '"';
          }
          else if (value == Infinity) {
              text = 'Infinity';
          }
          else if (value == -Infinity) {
              text = '-Infinity';
          }
          else if (isNaN(value)) {
              text = 'NaN';
          }
          else {
              text = value.toString();
          }
          return indentation + text;
      }
  };
  /**
   * Pairs an element data type name with a keras.TensorShape.
   */
  keras.TensorType = class {

      /**
       * @param {string} dataType - element data type name.
       * @param {keras.TensorShape} shape - tensor shape.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Data type name immediately followed by the shape text, e.g. 'float32[3,3]'. */
      toString() {
          const shapeText = this._shape.toString();
          return this._dataType + shapeText;
      }
  };
  /**
   * List of tensor dimensions.
   */
  keras.TensorShape = class {

      /**
       * @param {Array|undefined|null} dimensions - dimension sizes; entries may be
       *     null/undefined for unknown sizes.
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /**
       * Renders the shape as '[d0,d1,...]', or '' when no dimension list is set.
       * Unknown (null/undefined) dimensions are rendered as '?' instead of throwing
       * a TypeError from `null.toString()`.
       */
      toString() {
          if (!this._dimensions) {
              return '';
          }
          const parts = this._dimensions.map((dimension) =>
              (dimension === null || dimension === undefined) ? '?' : dimension.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  /**
   * Registry of layer/operator schemas loaded from 'keras-metadata.json' and
   * shared through a static cache (`keras.Metadata._metadata`).
   */
  keras.Metadata = class {

      /**
       * Returns the shared metadata instance, requesting 'keras-metadata.json'
       * through `host` on first use.
       *
       * If the request rejects or constructing the instance throws (e.g. invalid
       * JSON), an empty metadata instance is cached and resolved instead.
       *
       * @param {object} host - exposes `request(base, name, encoding)` returning a
       *     Promise of the file text (NOTE(review): signature inferred from this call).
       * @returns {Promise<keras.Metadata>} the cached metadata instance.
       */
      static open(host) {
          if (keras.Metadata._metadata) {
              return Promise.resolve(keras.Metadata._metadata);
          }
          return host.request(null, 'keras-metadata.json', 'utf-8').then((data) => {
              keras.Metadata._metadata = new keras.Metadata(data);
              return keras.Metadata._metadata;
          }).catch(() => {
              keras.Metadata._metadata = new keras.Metadata(null);
              // Resolve with the fallback instance itself so callers always receive
              // a usable keras.Metadata (the former `_metadatas` was always undefined).
              return keras.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text holding an array of `{ name, schema }`
       *     entries, or a falsy value for an empty registry.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (data) {
              const items = JSON.parse(data);
              if (items) {
                  for (const item of items) {
                      // Entries missing a name or schema are ignored; registered
                      // schemas are tagged with their type name.
                      if (item.name && item.schema) {
                          item.schema.name = item.name;
                          this._map.set(item.name, item.schema);
                      }
                  }
              }
          }
      }

      /**
       * @param {string} name - type name.
       * @returns {object|undefined} the registered schema, if any.
       */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Looks up the schema of attribute `name` on type `type`.
       * The first lookup for a key caches every attribute of that type's schema;
       * misses are cached as null.
       *
       * @returns {object|null} the attribute schema, or null when unknown.
       */
      attribute(type, name) {
          const key = type + ':' + name;
          if (!this._attributeCache.has(key)) {
              const schema = this.type(type);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      this._attributeCache.set(type + ':' + attribute.name, attribute);
                  }
              }
              if (!this._attributeCache.has(key)) {
                  this._attributeCache.set(key, null);
              }
          }
          return this._attributeCache.get(key);
      }
  };
  /**
   * Wrapper over a hierarchical container object (NOTE(review): presumably an
   * HDF5 group from the file reader - confirm) that reassembles chunked attributes.
   */
  keras.Group = class {

      /**
       * @param {object} group - wrapped object exposing `attribute(name)`,
       *     `group(name)` and `value`.
       */
      constructor(group) {
          this._group = group;
      }

      /**
       * Reads attribute `name`. When it is missing (falsy) but `name + '0'` exists,
       * the pieces `name0`, `name1`, ... are read up to the first missing one and
       * concatenated in order with `Array.prototype.concat` semantics.
       *
       * @param {string} name - attribute name.
       * @returns {*} the attribute, the concatenated pieces, or the original falsy value.
       */
      attribute(name) {
          const value = this._group.attribute(name);
          if (value || !this._group.attribute(name + '0')) {
              return value;
          }
          const pieces = [];
          for (let index = 0; ; index++) {
              const piece = this._group.attribute(name + index.toString());
              if (!piece) {
                  break;
              }
              pieces.push(piece);
          }
          return pieces.reduce((joined, piece) => joined.concat(piece), []);
      }

      /**
       * @param {string} name - child group name.
       * @returns {keras.Group|null} the wrapped child, or null if absent.
       */
      group(name) {
          const child = this._group.group(name);
          return child ? new keras.Group(child) : null;
      }

      get value() {
          return this._group.value;
      }
  };
  /**
   * Collection of weight tensors bucketed by layer name.
   */
  keras.Weights = class {

      constructor() {
          this._map = new Map();
      }

      /**
       * Appends `tensor` to the bucket for `layer_name`, creating the bucket if needed.
       */
      add(layer_name, tensor) {
          let bucket = this._map.get(layer_name);
          if (!bucket) {
              bucket = [];
              this._map.set(layer_name, bucket);
          }
          bucket.push(tensor);
      }

      /**
       * Finds the tensors belonging to layer `name`.
       *
       * With `group`: searches the bucket keyed by the first '/'-separated segment of
       * `group`, first for tensor names starting with `name + '/'`, then with
       * `group + '/' + name + '/'`.
       * Without `group`: returns the bucket keyed by `name` if non-empty, otherwise
       * the tensors in the '' bucket whose names start with `name + '/'`.
       *
       * @returns {Array} the matching tensors, or an empty array.
       */
      get(group, name) {
          if (group) {
              const bucket = this._map.get(group.split('/')[0]);
              if (bucket) {
                  for (const prefix of [ name + '/', group + '/' + name + '/' ]) {
                      const matches = bucket.filter((tensor) => tensor.name.startsWith(prefix));
                      if (matches.length > 0) {
                          return matches;
                      }
                  }
              }
              return [];
          }
          const direct = this._map.get(name);
          if (direct && direct.length > 0) {
              return direct;
          }
          const unnamed = this._map.get('');
          if (unnamed && unnamed.length > 0) {
              const matches = unnamed.filter((tensor) => tensor.name.startsWith(name + '/'));
              if (matches.length > 0) {
                  return matches;
              }
          }
          return [];
      }

      /** Iterator over the registered layer names. */
      keys() {
          return this._map.keys();
      }
  };
  /**
   * Error type raised by this module. Its `name` is a fixed human-readable
   * string rather than a class identifier; the constructor argument becomes
   * the error message.
   */
  keras.Error = class extends Error {

      /**
       * @param {string} text - description of the failure.
       */
      constructor(text) {
          super(text);
          this.name = 'Error loading Keras model.';
      }
  };
  // CommonJS export: publish the model factory only when a `module.exports`
  // object is available.
  if (!(typeof module === 'undefined' || typeof module.exports !== 'object')) {
      module.exports.ModelFactory = keras.ModelFactory;
  }