| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488 |
- import * as json from './json.js';
- import * as python from './python.js';
// Namespace objects; the loader classes in this file are attached to them as properties.
const keras = {};
const tfjs = {};

/**
 * Recognizes Keras model artifacts (HDF5, JSON config/metadata, TensorFlow.js
 * containers, pickles, NumPy archives and saved-metadata protobufs) and opens
 * them as a keras.Model.
 */
keras.ModelFactory = class {

    /**
     * Probes the context for each supported representation in turn and registers
     * the first one that matches through context.set(type, value).
     * @param {object} context - provides identifier, peek(), tags(), stream and set().
     * @returns {Promise<*>} the result of context.set() for the detected type, or null when nothing matches.
     */
    async match(context) {
        const identifier = context.identifier;
        const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
        const group = await context.peek('hdf5');
        if (group && group.attributes && group.attributes.get('CLASS') !== 'hickle') {
            if (identifier === 'model.weights.h5') {
                return context.set('keras.model.weights.h5', group);
            }
            if (identifier === 'parameter.h5') {
                return context.set('hdf5.parameter.h5', group);
            }
            return context.set('keras.h5', group);
        }
        const obj = await context.peek('json');
        if (obj) {
            // These keys identify an MXNet symbol file, which is not handled here.
            if (obj.mxnet_version || (obj.nodes && obj.arg_nodes && obj.heads)) {
                return null;
            }
            if (obj.model_config || (obj.class_name && obj.config)) {
                return context.set('keras.config.json', obj);
            }
            if (identifier === 'metadata.json' && obj.keras_version) {
                return context.set('keras.metadata.json', obj);
            }
        }
        const container = await tfjs.Container.open(context);
        if (container) {
            return context.set('tfjs', container);
        }
        const pickle = await context.peek('pkl');
        if (pickle && pickle.__class__ &&
            pickle.__class__.__module__ === 'keras.engine.sequential' &&
            pickle.__class__.__name__ === 'Sequential') {
            // Must be 'keras.pickle': that is the only pickle type open() dispatches on.
            return context.set('keras.pickle', pickle);
        }
        // model.weights.npz
        const entries = await context.peek('npz');
        const regex = /^(__root__|layers\/.+|_layer_checkpoint_dependencies\/.+)\.npy$/;
        if (entries instanceof Map && entries.size > 0 && Array.from(entries).every(([name]) => regex.test(name))) {
            return context.set('keras.model.weights.npz', entries);
        }
        // keras_metadata.pb
        if (extension === 'pb' && context.stream && context.stream.length > 16) {
            const tags = await context.tags('pb');
            if (tags.size === 1 && tags.get(1) === 2) {
                const stream = context.stream;
                // Only the first 1024 bytes (at most) are inspected for the JSON markers.
                const buffer = stream.peek(Math.min(stream.length, 1024));
                const content = String.fromCharCode.apply(null, buffer);
                if (/root"/.test(content) && /\{\s*"class_name"\s*:/.test(content)) {
                    return context.set('keras.pb.SavedMetadata');
                }
            }
        }
        return null;
    }

    /**
     * Decides whether another match should be kept alongside this context's match.
     * @param {object} context - the context matched by this factory (context.type is read).
     * @param {object} match - the other candidate (match.type is read).
     * @returns {boolean} false for the type pairs listed below, true otherwise.
     */
    filter(context, match) {
        if (context.type === 'keras.metadata.json' && (match.type === 'keras.config.json' || match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
            return false;
        }
        if (context.type === 'keras.config.json' && (match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
            return false;
        }
        if (context.type === 'tfjs' && match.type === 'tf.tfjs.weights') {
            return false;
        }
        return true;
    }

    /**
     * Builds a keras.Model from the value registered by match().
     * @param {object} context - matched context; context.type selects the reader.
     * @returns {Promise<keras.Model>} the opened model.
     * @throws {keras.Error} when the content is malformed or context.type is not supported.
     */
    async open(context) {
        // Fetches a sibling file and parses it as JSON; returns null when the fetch fails.
        const request_json = async (context, name) => {
            let content = null;
            try {
                content = await context.fetch(name);
            } catch {
                return null;
            }
            return await content.read('json');
        };
        // Synthesizes a 'Model' config with one layer per weights store entry,
        // used when only weights (no config.json) are available.
        const _create_config = (weights_store) => {
            const config = {};
            config.class_name = 'Model';
            config.config = {};
            config.config.layers = [];
            const snake_to_pascal_case = (name) => {
                return name.replace(/(^|_|\d)([a-z])/g, (match, p1, p2) => p1 === '_' ? p2.toUpperCase() : p1 + p2.toUpperCase());
            };
            for (const [name, value] of weights_store) {
                const layer = {};
                layer.name = name;
                // Layer type is derived from the last path segment without its numeric suffix.
                layer.class_name = name.split('/').pop().replace(/_[0-9]+$/, '');
                layer.class_name = snake_to_pascal_case(layer.class_name);
                layer.config = {};
                layer.config.name = name;
                layer._trainable_variables = value;
                config.config.layers.push(layer);
            }
            return config;
        };
        // Attaches stored weights to a config node, recursing into nested layer lists.
        const _load_state = (trackable, weights_store, assets_store, inner_path) => {
            inner_path = inner_path || '';
            if (trackable && trackable.config && Array.isArray(trackable.config.layers)) {
                /* eslint-disable no-use-before-define */
                _load_container_state(trackable, weights_store, assets_store, inner_path ? `${inner_path}/layers` : 'layers');
                /* eslint-enable no-use-before-define */
            } else {
                const weights = weights_store.get(inner_path);
                if (weights) {
                    trackable._trainable_variables = weights;
                }
            }
        };
        const pascal_to_snake_case = (name) => {
            name = name.replace(/\W+/g, "");
            name = name.replace(/(.)([A-Z][a-z]+)/g, (match, p1, p2) => `${p1}_${p2}`);
            name = name.replace(/([a-z])([A-Z])/g, (match, p1, p2) => `${p1}_${p2}`);
            return name.toLowerCase();
        };
        // Computes the weights store key of every layer in a container: the snake-cased
        // class name, with '_<n>' appended from the second occurrence of that name onwards.
        const _load_container_state = (container, weights_store, assets_store, inner_path) => {
            const used_names = new Map();
            for (const trackable of container.config.layers) {
                let name = pascal_to_snake_case(trackable.class_name);
                if (used_names.has(name)) {
                    const next = used_names.get(name) + 1;
                    used_names.set(name, next);
                    name = `${name}_${next}`;
                } else {
                    used_names.set(name, 0);
                }
                _load_state(trackable, weights_store, assets_store, `${inner_path}/${name}`);
            }
        };
        // Walks the 'layers' (or '_layer_checkpoint_dependencies') groups of an HDF5
        // weights file and returns a Map of layer path -> keras.Tensor[].
        const read_weights_hdf5 = (group) => {
            const weights_store = new Map();
            const stack = [[group, '']];
            while (stack.length > 0) {
                const [current, path] = stack.pop();
                if (current.groups instanceof Map) {
                    const checkpoint = current.groups.get('layers') || current.groups.get('_layer_checkpoint_dependencies');
                    if (checkpoint) {
                        for (const [key, layer] of checkpoint.groups) {
                            const name = `${path ? `${path}/` : ''}layers/${key}`;
                            // Nested layers are visited later through the stack.
                            stack.push([layer, name]);
                            const values = [];
                            for (const vars of layer.groups) {
                                for (const [variable_name, variable_group] of vars[1].groups) {
                                    const variable = variable_group.value;
                                    if (variable) {
                                        const layout = variable.littleEndian ? '<' : '>';
                                        const tensor = new keras.Tensor(variable_name, variable.shape, variable.type, null, null, layout, variable.data);
                                        values.push(tensor);
                                    }
                                }
                            }
                            if (values.length > 0) {
                                weights_store.set(name, values);
                            }
                        }
                    }
                }
            }
            return weights_store;
        };
        // Converts the entries of a model.weights.npz archive into a Map of
        // layer path -> keras.Tensor[].
        const read_weights_numpy = (entries) => {
            const weights_store = new Map();
            for (const [path, array] of entries) {
                const file = path.split('/').map((name) => name === '_layer_checkpoint_dependencies' ? 'layers' : name).join('/');
                if (file.endsWith('.npy') && file.startsWith('layers/')) {
                    if (array.dtype.name === 'object' && array.shape.length === 0 && Array.isArray(array.data) && array.data.length === 1) {
                        const values = Object.values(array.data[0]).map((array) => {
                            const stride = array.strides.map((stride) => stride / array.itemsize);
                            const dataType = array.dtype.__name__;
                            const values = dataType === 'string' || dataType === 'object' ? array.flatten().tolist() : array.tobytes();
                            const encoding = dataType === 'string' || dataType === 'object' ? '|' : array.dtype.byteorder;
                            return new keras.Tensor('', array.shape, dataType, stride, null, encoding, values);
                        });
                        if (values.length > 0) {
                            const name = file.replace(/\.npy$/, '');
                            weights_store.set(name, values);
                        }
                    }
                }
            }
            return weights_store;
        };
        // Tries each sibling weights file in order and returns the first weights store
        // that can be read; returns an empty Map when none is available.
        const request_weights = async (context) => {
            const formats = [
                ['model.weights.h5', 'hdf5', read_weights_hdf5],
                ['model.weights.npz', 'npz', read_weights_numpy],
            ];
            for (const [name, type, callback] of formats) {
                let content = null;
                try {
                    /* eslint-disable no-await-in-loop */
                    content = await context.fetch(name);
                    /* eslint-enable no-await-in-loop */
                } catch {
                    // continue regardless of error
                }
                if (content) {
                    /* eslint-disable no-await-in-loop */
                    const obj = await content.peek(type);
                    /* eslint-enable no-await-in-loop */
                    if (obj) {
                        return callback(obj);
                    }
                }
            }
            return new Map();
        };
        const open_model = async (format, producer, backend, config, weights) => {
            const metadata = await context.metadata('keras-metadata.json');
            return new keras.Model(metadata, format, producer, backend, config, weights);
        };
        switch (context.type) {
            case 'keras.config.json': {
                const obj = context.value;
                const config = obj.model_config ? obj.model_config : obj;
                const backend = obj.backend || '';
                let version = obj.keras_version ? obj.keras_version : null;
                if (!version) {
                    const metadata = await request_json(context, 'metadata.json');
                    if (metadata && metadata.keras_version) {
                        version = metadata.keras_version;
                    }
                }
                const format = `Keras${version ? ` v${version}` : ''}`;
                const weights_store = await request_weights(context);
                _load_state(config, weights_store);
                return open_model(format, '', backend, config, null);
            }
            case 'keras.model.weights.h5': {
                const group = context.value;
                const weights_store = read_weights_hdf5(group);
                const metadata = await request_json(context, 'metadata.json');
                let config = await request_json(context, 'config.json');
                const name = config ? 'Keras' : 'Keras Weights';
                const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
                if (config) {
                    _load_state(config, weights_store);
                } else {
                    config = _create_config(weights_store);
                }
                return await open_model(format, '', '', config, null);
            }
            case 'keras.model.weights.npz': {
                const entries = context.value;
                const weights_store = read_weights_numpy(entries);
                const metadata = await request_json(context, 'metadata.json');
                let config = await request_json(context, 'config.json');
                const name = config ? 'Keras' : 'Keras Weights';
                const format = name + (metadata && metadata.keras_version ? ` v${metadata.keras_version}` : '');
                if (config) {
                    _load_state(config, weights_store);
                } else {
                    config = _create_config(weights_store);
                }
                return await open_model(format, '', '', config, null);
            }
            case 'keras.metadata.json': {
                const metadata = context.value;
                let config = await request_json(context, 'config.json');
                const name = config ? 'Keras' : 'Keras Weights';
                const format = name + (metadata.keras_version ? ` v${metadata.keras_version}` : '');
                const weights_store = await request_weights(context);
                if (!config && (!weights_store || weights_store.size === 0)) {
                    throw new keras.Error("'config.json' or 'model.weights.*' not present.");
                }
                if (config) {
                    _load_state(config, weights_store);
                } else {
                    config = _create_config(weights_store);
                }
                return await open_model(format, '', '', config, null);
            }
            case 'hdf5.parameter.h5':
            case 'keras.h5': {
                // Prefers the nested 'model/kerasmodel' group when it carries the model config.
                const find_root_group = (root_group) => {
                    const kerasmodel = root_group.group('model/kerasmodel');
                    if (kerasmodel && kerasmodel.attributes.has('model_config')) {
                        return kerasmodel;
                    }
                    return root_group;
                };
                const read_model_config = (group) => {
                    if (group.attributes.has('model_config')) {
                        const buffer = group.attributes.get('model_config');
                        const reader = json.TextReader.open(buffer);
                        if (reader) {
                            return reader.read();
                        }
                    }
                    return null;
                };
                // Reads attribute 'name', or concatenates the chunked attributes
                // 'name0', 'name1', ...; returns null when neither form exists.
                const load_attributes_from_hdf5_group = (group, name) => {
                    if (group.attributes.has(name)) {
                        return group.attributes.get(name);
                    }
                    if (group.attributes.has(`${name}0`)) {
                        let index = 0;
                        let value = [];
                        while (group.attributes.has(name + index.toString())) {
                            const chunk = group.attributes.get(name + index.toString());
                            value = value.concat(chunk);
                            index++;
                        }
                        return value;
                    }
                    return null;
                };
                const weights = new keras.Weights();
                const group = context.value;
                const root_group = find_root_group(group);
                const model_config = read_model_config(root_group);
                if (model_config) {
                    const backend = root_group.attributes.get('backend') || '';
                    const version = root_group.attributes.get('keras_version') || '';
                    const format = `Keras${version ? ` v${version}` : ''}`;
                    const model_weights_group = root_group.group('model_weights');
                    if (model_weights_group) {
                        const layer_names = load_attributes_from_hdf5_group(model_weights_group, 'layer_names');
                        // 'layer_names' may be absent (null); treat that as "no weights".
                        if (Array.isArray(layer_names)) {
                            for (const layer_name of layer_names) {
                                const layer_weights = model_weights_group.group(layer_name);
                                if (layer_weights) {
                                    const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
                                    if (Array.isArray(weight_names) && weight_names.length > 0) {
                                        for (const weight_name of weight_names) {
                                            const weight = layer_weights.group(weight_name);
                                            if (weight && weight.value) {
                                                const variable = weight.value;
                                                const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, variable.littleEndian ? '<' : '>', variable.data);
                                                weights.add(layer_name, tensor);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                    if (!model_config.class_name) {
                        throw new keras.Error("'class_name' is not present.");
                    }
                    return open_model(format, '', backend, model_config, weights);
                }
                // No model config: a Keras weights-only file lists its layers on the root group.
                const layer_names = load_attributes_from_hdf5_group(root_group, 'layer_names');
                if (layer_names && Array.isArray(layer_names)) {
                    const version = root_group.attributes.get('keras_version') || '';
                    const format = `Keras Weights${version ? ` v${version}` : ''}`;
                    const backend = root_group.attributes.get('backend') || '';
                    for (const layer_name of layer_names) {
                        const layer_weights = root_group.group(layer_name);
                        if (layer_weights) {
                            const weight_names = load_attributes_from_hdf5_group(layer_weights, 'weight_names');
                            if (Array.isArray(weight_names) && weight_names.length > 0) {
                                for (const weight_name of weight_names) {
                                    const weight = layer_weights.group(weight_name);
                                    if (weight && weight.value) {
                                        const variable = weight.value;
                                        // Module name is the weight path without its last segment,
                                        // prefixed with the layer name unless it already starts with it.
                                        const components = weight_name.split('/');
                                        components.pop();
                                        const name = (components.length === 0 || components[0] !== layer_name) ? [layer_name].concat(components).join('/') : components.join('/');
                                        const encoding = variable.littleEndian ? '<' : '>';
                                        const tensor = new keras.Tensor(weight_name, variable.shape, variable.type, null, null, encoding, variable.data);
                                        weights.add(name, tensor);
                                    }
                                }
                            }
                        }
                    }
                    return open_model(format, '', backend, null, weights);
                }
                // Generic HDF5 weights: the root may only carry the 'nb_layers' attribute.
                const rootKeys = new Set(root_group.attributes.keys());
                rootKeys.delete('nb_layers');
                if (rootKeys.size > 0 || root_group.value !== null) {
                    throw new keras.Error('File format is not HDF5 Weights.');
                }
                const format = 'HDF5 Weights';
                let weights_group = root_group;
                // Unwrap a single attribute-less wrapper group.
                if (root_group.attributes.size === 0 && root_group.value === null && root_group.groups.size === 1) {
                    const group = root_group.groups.values().next().value;
                    if (group.attributes.size === 0 && group.value === null) {
                        weights_group = group;
                    }
                }
                const tensorKeys = new Set(['name', 'shape', 'quantization']);
                const groups = Array.from(weights_group.groups.values());
                // Flat layout: every child is a bare variable. 'groups' is a Map, so emptiness
                // has to be tested with '.size' ('.length' is undefined on a Map).
                if (groups.every((group) => group.attributes.size === 0 && group.groups.size === 0 && group.value !== null)) {
                    for (const group of groups) {
                        const variable = group.value;
                        const layout = variable.littleEndian ? '<' : '>';
                        const tensor = new keras.Tensor(group.name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
                        weights.add('', tensor);
                    }
                    return open_model(format, '', '', null, weights);
                }
                // Two-level layout: module groups (only tensorKeys attributes) holding bare variables.
                // 'attributes' is a Map as well, hence '.size' rather than Object.keys().
                if (groups.every((group) => group.value === null && Array.from(group.attributes.keys()).filter((key) => !tensorKeys.has(key)).length === 0 && Array.from(group.groups.values()).every((variable) => variable.attributes.size === 0 && variable.value !== null))) {
                    for (const group of groups) {
                        const module = group.attributes.has('name') ? group.attributes.get('name') : group.name;
                        for (const variableGroup of group.groups.values()) {
                            if (variableGroup.attributes.size !== 0 || variableGroup.groups.size !== 0) {
                                throw new keras.Error('Variable format is not HDF5 Weights.');
                            }
                            const variable = variableGroup.value;
                            if (!variable) {
                                throw new keras.Error('Variable value is not HDF5 Weights.');
                            }
                            const name = module ? [module, variableGroup.name].join('/') : variableGroup.name;
                            const layout = variable.littleEndian ? '<' : '>';
                            const tensor = new keras.Tensor(name, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
                            weights.add(module, tensor);
                        }
                    }
                    return open_model(format, '', '', null, weights);
                }
                // Arbitrary nesting: recurse through attribute-less groups down to the variables.
                const walk = function(group) {
                    if (group.attributes.size === 0 && group.value === null && group.groups.size > 0) {
                        for (const subGroup of group.groups.values()) {
                            walk(subGroup);
                        }
                        return;
                    }
                    const subKeys = new Set(['index', 'need_grad']);
                    const attributes = Array.from(group.attributes.keys());
                    const match = attributes.filter((key) => !subKeys.has(key)).length === 0;
                    if (match && group.value !== null && group.groups.size === 0) {
                        const variable = group.value;
                        const variableName = group.path;
                        // Module name is the variable path without its last segment.
                        let moduleName = variableName;
                        const parts = variableName.split('/');
                        if (parts.length > 1) {
                            parts.pop();
                            moduleName = parts.join('/');
                        }
                        const layout = variable.littleEndian ? '<' : '>';
                        const tensor = new keras.Tensor(variableName, variable.shape, variable.type, null, null, layout, variable.type === 'string' ? variable.value : variable.data);
                        weights.add(moduleName, tensor);
                        return;
                    }
                    throw new keras.Error('Module group format is not HDF5 Weights.');
                };
                walk(weights_group);
                return open_model(format, '', '', null, weights);
            }
            case 'tfjs': {
                const target = context.value;
                await target.read();
                return open_model(target.format, target.producer, target.backend, target.config, target.weights);
            }
            case 'keras.pickle': {
                const obj = context.value;
                // NOTE(review): weight buffers are unpickled through python.Execution below;
                // confirm that implementation is safe for untrusted input.
                const execution = new python.Execution();
                const decoder = new TextDecoder('utf-8');
                const format = `Keras Pickle${obj.keras_version ? ` v${decoder.decode(obj.keras_version)}` : ''}`;
                const backend = obj.backend ? decoder.decode(obj.backend) : '';
                const reader = json.TextReader.open(obj.model_config);
                if (!reader) {
                    throw new keras.Error("'model_config' is not present.");
                }
                const model_config = reader.read();
                const weights = new keras.Weights();
                const model_weights_group = obj.model_weights;
                if (model_weights_group) {
                    const layer_names = model_weights_group.layer_names.map((buffer) => decoder.decode(buffer));
                    for (const layer_name of layer_names) {
                        const layer_weights = model_weights_group[layer_name];
                        if (layer_weights) {
                            const weight_names = layer_weights.weight_names.map((buffer) => decoder.decode(buffer));
                            if (Array.isArray(weight_names) && weight_names.length > 0) {
                                for (const weight_name of weight_names) {
                                    const buffer = layer_weights[weight_name];
                                    const pickle = execution.__import__('pickle');
                                    const unpickler = new pickle.Unpickler(buffer);
                                    const variable = unpickler.load();
                                    const tensor = new keras.Tensor(weight_name, variable.shape, variable.dtype.__name__, null, null, '<', variable.data);
                                    weights.add(layer_name, tensor);
                                }
                            }
                        }
                    }
                }
                return open_model(format, '', backend, model_config, weights);
            }
            case 'keras.pb.SavedMetadata': {
                keras.proto = await context.require('./keras-proto');
                const format = 'Keras Saved Metadata';
                const reader = await context.read('protobuf.binary');
                const saved_metadata = keras.proto.third_party.tensorflow.python.keras.protobuf.SavedMetadata.decode(reader);
                if (!saved_metadata || !Array.isArray(saved_metadata.nodes) ||
                    !saved_metadata.nodes.every((node) => node && typeof node.metadata === 'string' && node.metadata.length > 0)) {
                    throw new keras.Error('Invalid keras.protobuf.SavedMetadata.');
                }
                // Each node's metadata string is JSON; parse it in place and index nodes by path.
                const objects = new Map();
                for (const node of saved_metadata.nodes) {
                    const reader = json.TextReader.open(node.metadata);
                    node.metadata = reader.read();
                    objects.set(node.node_path, node);
                }
                const root = objects.get('root');
                if (!root) {
                    throw new keras.Error("Keras saved metadata 'root' node is not present.");
                }
                const model_config = root.metadata;
                return open_model(format, '', '', model_config, null);
            }
            default: {
                throw new keras.Error(`Unsupported Keras format '${context.type}'.`);
            }
        }
    }
};
/**
 * Top-level model object: records the format, runtime and producer strings
 * and wraps a single keras.Graph built from the config and weights.
 */
keras.Model = class {

    /**
     * @param {*} metadata - raw metadata, wrapped in keras.GraphMetadata before use.
     * @param {string} format - stored as this.format.
     * @param {string} producer - stored as this.producer.
     * @param {string} backend - stored as this.runtime.
     * @param {*} config - passed to keras.Graph.
     * @param {*} weights - passed to keras.Graph.
     */
    constructor(metadata, format, producer, backend, config, weights) {
        const graphMetadata = new keras.GraphMetadata(metadata);
        const graph = new keras.Graph(graphMetadata, config, weights);
        Object.assign(this, {
            format,
            runtime: backend,
            producer,
            modules: [graph],
        });
    }
};
- keras.Graph = class {
- constructor(metadata, config, weights, group) {
- this.inputs = [];
- this.outputs = [];
- this.nodes = [];
- group = group || '';
- const values = new Map();
- values.map = (name, type, tensor) => {
- if (tensor) {
- return new keras.Value(name, type || null, tensor);
- }
- if (!values.has(name)) {
- values.set(name, new keras.Value(name, type || null, tensor || null));
- } else if (type || tensor) {
- throw new keras.Error(`Duplicate value '${name}'.`);
- }
- return values.get(name);
- };
- if (config) {
- const getInputType = (layer) => {
- if (layer && layer.config) {
- let dataType = '?';
- let shape = [];
- const config = layer.config;
- if (config.dtype) {
- dataType = config.dtype;
- delete config.dtype;
- }
- if (Array.isArray(config.batch_input_shape)) {
- shape = config.batch_input_shape.map((s) => s === null ? '?' : s);
- delete config.batch_input_shape;
- } else if (config.batch_input_shape &&
- config.batch_input_shape.class_name === '__tuple__' &&
- Array.isArray(config.batch_input_shape.items)) {
- shape = config.batch_input_shape.items.map((s) => s === null ? '?' : s);
- delete config.batch_input_shape;
- }
- return new keras.TensorType(dataType, new keras.TensorShape(shape));
- }
- return null;
- };
- this.name = config.name || (config.config && config.config.name ? config.config.name : '');
- this.description = config.class_name;
- let baseType = config.class_name;
- switch (baseType) {
- case '__Function__':
- this.type = 'function';
- break;
- case 'Sequential':
- case 'Functional':
- case 'Model': {
- break;
- }
- case 'Tokenizer': {
- config = { config: { layers: [config] } };
- baseType = 'Functional';
- break;
- }
- default: {
- const layers = Array.from(config.layers ? config.layers : config);
- const sequential = layers.every((layer) => layer.inbound_nodes === undefined);
- baseType = sequential ? 'Sequential' : 'Functional';
- break;
- }
- }
- switch (baseType) {
- case 'Sequential': {
- config = config.config;
- const outputs = null;
- let name = 'input';
- let index = -1;
- const layers = Array.from(config.layers ? config.layers : config);
- while (layers.length > 0) {
- const layer = layers.shift();
- let current = index.toString();
- index++;
- if (index === 0) {
- const type = getInputType(layer);
- let remove = false;
- if (layer.class_name === 'InputLayer' && layer.config && layer.config.name) {
- name = layer.config.name;
- remove = true;
- }
- const value = values.map(name, type);
- const argument = new keras.Argument(name, [value]);
- this.inputs.push(argument);
- if (remove) {
- continue;
- }
- }
- const nodeInputs = [{ name }];
- if (layer.config && layer.config.name) {
- current = layer.config.name;
- }
- name = current;
- let nodeOutputs = [name];
- if (index === layers.length) {
- if (outputs && outputs.length > 0) {
- nodeOutputs = [outputs[0]];
- name = null;
- }
- }
- layer.inputs = nodeInputs;
- layer.outputs = nodeOutputs;
- const node = new keras.Node(metadata, layer, group, weights, values);
- this.nodes.push(node);
- }
- if (name) {
- const value = values.map(name);
- const argument = new keras.Argument(name, [value]);
- this.outputs.push(argument);
- }
- break;
- }
- case '__Function__':
- case 'Functional':
- case 'Model': {
- config = config.config;
- const nodes = new Map();
- if (config.layers) {
- const is_constant = (item) => {
- return Array.isArray(item) && (item.length === 3 || item.length === 4) && item[0] === '_CONSTANT_VALUE' && item[1] === -1;
- };
- const is_connection = (item) => {
- return Array.isArray(item) && (item.length === 3 || item.length === 4) && typeof item[0] === 'string' && typeof item[1] === 'number' && typeof item[2] === 'number';
- };
- const read_value = (input_data) => {
- if (!Array.isArray(input_data)) {
- return input_data;
- }
- const transform = (value) => {
- if (value.every((item) => is_constant(item))) {
- for (let i = 0; i < value.length; i++) {
- value[i] = value[i][2];
- }
- } else if (value.every((item) => Array.isArray(item))) {
- const dims = value.map((item) => transform(item));
- const [dim] = dims;
- for (let i = 1; i < dims.length; i++) {
- if (dim.length === dims[i].length) {
- if (!dims[i].every((value, i) => value === dim[i])) {
- throw new python.Error('Invalid array shape.');
- }
- }
- }
- return [value.length].concat(dim);
- }
- return [value.length];
- };
- const shape = transform(input_data);
- const flatten = (input) => input.reduce((a, b) => a.concat(Array.isArray(b) ? flatten(b) : b), []);
- const value = flatten(input_data);
- return { shape, value };
- };
- const functional = config.layers.every((layer) => Array.isArray(layer.inbound_nodes));
- if (functional) {
- const read_connection = (input_data) => {
- const [node_name, node_index, tensor_index] = input_data;
- const inbound_node_key = `${node_name}[${node_index}]`;
- const inbound_node = nodes.get(inbound_node_key);
- const tensor_key = `${node_name}[${node_index}][${tensor_index}]`;
- if (inbound_node) {
- while (tensor_index >= inbound_node.outputs.length) {
- inbound_node.outputs.push(undefined);
- }
- inbound_node.outputs[tensor_index] = tensor_key;
- }
- return tensor_key;
- };
- const process_node = (node, inbound_node) => {
- if (Array.isArray(inbound_node) && inbound_node.length === 4 && typeof inbound_node[0] === 'string') {
- const key = read_connection(inbound_node);
- node.inputs.push({ name: key });
- for (const [name, value] of Object.entries(inbound_node[3])) {
- if (is_connection(value)) {
- const key = read_connection(value);
- node.inputs.push({ name: key });
- } else if (Array.isArray(value)) {
- const array = read_value(value);
- node.args[name] = array;
- } else {
- node.args[name] = value;
- }
- }
- } else if (Array.isArray(inbound_node)) {
- for (const input_data of inbound_node) {
- // [ 'conv2d', 0, 0 ] or [ 'conv2d', 0, 0, {} ]
- if (Array.isArray(input_data) && is_connection(input_data)) {
- const key = read_connection(input_data);
- node.inputs.push({ name: key });
- } else if (Array.isArray(input_data) && input_data.every((item) => is_connection(item))) {
- for (const input of input_data) {
- const key = read_connection(input);
- node.inputs.push({ name: key });
- }
- } else if (Array.isArray(input_data)) {
- const value = read_value(input_data);
- node.inputs.push(value);
- } else {
- throw new keras.Error(`Invalid inbound connection '${JSON.stringify(input_data)}'.`);
- }
- }
- } else if (inbound_node && inbound_node.args) {
- for (const arg of inbound_node.args) {
- if (arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history)) {
- const key = read_connection(arg.config.keras_history);
- node.inputs.push({ name: key });
- } else if (Array.isArray(arg) && arg.every((arg) => arg && arg.class_name === '__keras_tensor__' && arg.config && is_connection(arg.config.keras_history))) {
- for (const input of arg) {
- const key = read_connection(input.config.keras_history);
- node.inputs.push({ name: key });
- }
- }
- }
- }
- };
- let legacy_format = true;
- for (const layer of config.layers) {
- if (Array.isArray(layer.inbound_nodes)) {
- for (const inbound_node of layer.inbound_nodes) {
- if (Array.isArray(inbound_node.args)) {
- legacy_format = false;
- }
- }
- }
- }
- for (const layer of config.layers) {
- const class_name = layer.class_name;
- let first_index = 0;
- if (legacy_format) {
- const keys = new Set(Object.keys(layer.config));
- const is_functional_config = keys.has('name') && keys.has('layers') && keys.has('input_layers') && keys.has('output_layers');
- if (class_name === 'Sequential' ||
- (is_functional_config && Array.isArray(layer.config.layers) && layer.config.layers.length > 0 && layer.config.layers[0].class_name === 'InputLayer')) {
- first_index++;
- }
- }
- if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 0) {
- layer.inputs = [];
- layer.outputs = [];
- layer.args = {};
- nodes.set(`${layer.name}[${first_index}]`, layer);
- } else if (Array.isArray(layer.inbound_nodes) && layer.inbound_nodes.length === 1) {
- layer.inputs = [];
- layer.outputs = [];
- layer.args = {};
- [layer.inbound_node] = layer.inbound_nodes;
- nodes.set(`${layer.name}[${first_index}]`, layer);
- } else {
- let config = {};
- switch (class_name) {
- case 'Functional':
- case 'Sequential':
- case 'Model': {
- config = layer;
- break;
- }
- default: {
- config.class_name = '__Function__';
- config.name = layer.name;
- config.config = {};
- config.config.layers = [{ ...layer }];
- delete config.config.layers[0].inbound_nodes;
- delete config.config.layers[0].input_layers;
- delete config.config.layers[0].output_layers;
- break;
- }
- }
- const type = new keras.Graph(metadata, config, weights, '');
- for (let i = 0; i < layer.inbound_nodes.length; i++) {
- const index = i + first_index;
- const key = `${layer.name}[${index}]`;
- const node = {};
- node.name = key;
- node.class_name = '__Function__';
- node.config = {};
- node.config.name = key;
- node.inputs = [];
- node.outputs = [];
- node.args = {};
- node.__type__ = type;
- node.inbound_node = layer.inbound_nodes[i];
- nodes.set(key, node);
- }
- }
- }
- for (const entry of nodes) {
- if (entry[1].inbound_node) {
- process_node(entry[1], entry[1].inbound_node);
- }
- }
- if (Array.isArray(config.input_layers)) {
- if (config.input_layers.length === 3 && typeof config.input_layers[0] === 'string' && Number.isInteger(config.input_layers[1]) && Number.isInteger(config.input_layers[2])) {
- config.input_layers = [config.input_layers];
- }
- for (let i = 0; i < config.input_layers.length; i++) {
- const input_data = config.input_layers[i];
- const name = read_connection(input_data);
- const [node_name, node_index] = input_data;
- const inbound_node_key = `${node_name}[${node_index}]`;
- const node = nodes.get(inbound_node_key);
- let type = null;
- if (node && node.class_name === 'InputLayer') {
- type = getInputType(node);
- nodes.delete(name);
- nodes.delete(inbound_node_key);
- }
- const value = values.map(name, type);
- const argument = new keras.Argument(node_name, [value]);
- this.inputs.push(argument);
- }
- }
- if (Array.isArray(config.output_layers)) {
- if (config.output_layers.length === 3 && typeof config.output_layers[0] === 'string' && Number.isInteger(config.output_layers[1]) && Number.isInteger(config.output_layers[2])) {
- config.output_layers = [config.output_layers];
- }
- for (let i = 0; i < config.output_layers.length; i++) {
- const output_data = config.output_layers[i];
- const [name] = output_data;
- const key = read_connection(output_data);
- const value = values.map(key);
- const argument = new keras.Argument(name, [value]);
- this.outputs.push(argument);
- }
- }
- } else {
- for (const layer of config.layers) {
- layer.inputs = [];
- layer.outputs = [];
- layer.args = {};
- nodes.set(`${layer.name}[0]`, layer);
- }
- }
- }
- for (const entry of nodes) {
- const node = new keras.Node(metadata, entry[1], group, weights, values);
- this.nodes.push(node);
- }
- break;
- }
- default: {
- throw new keras.Error(`'${config.class_name}' is not supported.`);
- }
- }
- } else if (weights) {
- this.type = 'weights';
- for (const name of weights.keys()) {
- if (weights.get('', name).length <= 6) {
- const layer = { class_name: 'Weights', config: { name } };
- const node = new keras.Node(metadata, layer, '', weights, values);
- this.nodes.push(node);
- }
- }
- }
- }
- };
/**
 * Named argument of a graph or node.
 */
keras.Argument = class {

    /**
     * @param {string} name - Label of the argument.
     * @param {*} value - Payload stored as given (callers in this file pass arrays of values or raw attribute values).
     * @param {?string} [type=null] - Optional type hint.
     * @param {boolean} [visible=true] - Visibility flag stored on the instance.
     */
    constructor(name, value, type = null, visible = true) {
        // Same four own properties, assigned in the same order.
        Object.assign(this, { name, value, type, visible });
    }
};
/**
 * A named value, optionally backed by an initializer tensor.
 */
keras.Value = class {

    /**
     * @param {string} name - Value identifier; must be a string.
     * @param {*} type - Explicit type; when falsy, the initializer's type is used if an initializer is given.
     * @param {?Object} [initializer=null] - Optional initializer providing `type` and `quantization`.
     * @throws {keras.Error} When `name` is not a string.
     */
    constructor(name, type, initializer = null) {
        if (typeof name !== 'string') {
            throw new keras.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
        }
        let resolvedType = type;
        if (!type && initializer) {
            resolvedType = initializer.type;
        }
        this.name = name;
        this.type = resolvedType;
        this.quantization = (initializer && initializer.quantization) || null;
        this.initializer = initializer;
    }
};
/**
 * A single node of a Keras graph, built from a layer description.
 *
 * The constructor fills `name`, `group`, `type`, `inputs`, `outputs`,
 * `attributes` and `chain` (activation nodes derived from an `activation`
 * config entry), plus `inner` for wrapper layers that carry a nested layer.
 */
keras.Node = class {

    /**
     * @param {Object} metadata - Provides `type(name)`, `attribute(type, name)` and `add(type, metadata)`.
     * @param {Object} layer - Layer description; reads `class_name`, `config`, `args`, `inputs`,
     *     `outputs`, `name`, `__type__` and `_trainable_variables`.
     * @param {?string} group - Name prefix of the enclosing model, if any.
     * @param {?Object} weights - Weight lookup providing `get(group, name)`, or null.
     * @param {Object} values - Value table providing `map(name, type, initializer)`.
     * @throws {keras.Error} When an input cannot be resolved or the node type name is not a string.
     */
    constructor(metadata, layer, group, weights, values) {
        const config = layer.config || {};
        const args = layer.args || {};
        // NOTE: `inputs` aliases `layer.inputs` when present and is consumed (shift/splice) below.
        let inputs = layer.inputs || [];
        let outputs = layer.outputs || [];
        const name = config && config.name ? config.name : '';
        this.group = group || '';
        this.name = (this.group ? `${this.group}/` : '') + name;
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        this.chain = [];
        let names = [name];
        let class_name = layer.class_name;
        let model = false;
        switch (class_name) {
            case '__Function__': {
                this.type = layer.__type__;
                model = true;
                break;
            }
            case 'Model':
            case 'Functional':
            case 'Sequential': {
                const name = layer.name || (layer.config ? layer.config.name : '');
                this.type = new keras.Graph(metadata, layer, weights, (group ? `${group}/` : '') + name);
                model = true;
                if (config) {
                    // The nested graph now owns these; keep them out of the attribute list.
                    delete config.layers;
                    delete config.input_layers;
                    delete config.output_layers;
                }
                this.inputs = [new keras.Argument('inputs', inputs.map((input) => values.map(input.name)))];
                this.outputs = [new keras.Argument('outputs', outputs.map((name) => values.map(name)))];
                inputs = [];
                outputs = [];
                break;
            }
            case 'Wrapper':
            case 'Bidirectional':
            case 'TimeDistributed': {
                if (config && config.layer) {
                    const inner = config.layer;
                    delete config.layer;
                    this.inner = new keras.Node(metadata, inner, null, null, values);
                    if (class_name === 'Bidirectional' && inner.config.name) {
                        names = [`${name}/forward_${inner.config.name}`, `${name}/backward_${inner.config.name}`];
                        if (!group) {
                            group = name;
                        }
                    }
                }
                this.type = metadata.type(class_name) || { name: class_name };
                break;
            }
            case 'TFOpLambda': {
                if (config && config.function) {
                    class_name = config.function;
                    delete config.function;
                }
                this.type = metadata.type(class_name) || { name: class_name };
                break;
            }
            default: {
                this.type = metadata.type(class_name) || { name: class_name };
                break;
            }
        }
        if (layer._trainable_variables) {
            if (inputs.length === 0 && Array.isArray(this.type.inputs) && this.type.inputs.length > 0) {
                // weights-only, remove 'input' from type metadata
                this.type = { ...this.type };
                this.type.inputs = this.type.inputs.slice(1);
            }
            for (const variable of layer._trainable_variables) {
                inputs.push({ name: '', initializer: variable });
            }
        } else if (weights && !model) {
            for (const name of names) {
                let tensors = weights.get(group, name);
                if (tensors.length > 0) {
                    for (const initializer of tensors) {
                        inputs.push({ name: initializer.name, initializer });
                    }
                } else {
                    // Fall back to a lookup without the group prefix.
                    tensors = weights.get('', name);
                    for (const initializer of tensors) {
                        inputs.push({ name: initializer.name, initializer });
                    }
                }
            }
        }
        // Collected as [schema, name, value] triples and converted to keras.Argument at the end.
        const attributes = [];
        // Flattens a { class_name, config } object into a plain object tagged with `__type__`.
        const convertAttributeValue = (value) => {
            if (Array.isArray(value) || value !== Object(value)) {
                return value;
            }
            const obj = {};
            if (value.class_name) {
                obj.__type__ = value.class_name;
            }
            if (value.config) {
                const config = value.config;
                for (const [key, value] of Object.entries(config)) {
                    obj[key] = convertAttributeValue(value);
                }
            }
            return obj;
        };
        if (config && !Array.isArray(config)) {
            for (const [name, value] of Object.entries(config)) {
                if (class_name !== 'Activation' && name === 'activation' && value !== 'linear') {
                    if (typeof value === 'string') {
                        const config = { activation: value };
                        // Fix: hand the value table (not the attribute value) to the chained node.
                        const node = new keras.Node(metadata, { class_name: 'Activation', config }, null, null, values);
                        this.chain.push(node);
                    } else if (value && typeof value.class_name === 'string' && value.config) {
                        const type = value.class_name;
                        if (!metadata.type(type)) {
                            metadata.add(type, { name: type, category: 'Activation' });
                        }
                        // Fix: hand the value table (not the attribute value) to the chained node.
                        const node = new keras.Node(metadata, value, null, null, values);
                        this.chain.push(node);
                    }
                }
                if (name !== 'name' && name !== 'batch_input_shape') {
                    const schema = metadata.attribute(class_name, name);
                    attributes.push([schema, name, value]);
                }
            }
        }
        const innerType = this.inner ? this.inner.type : null;
        // NOTE(review): `innerType` is a type object, not a class name; confirm that
        // metadata.type() is meant to be called with it here.
        const innerMetadata = innerType ? metadata.type(innerType) : null;
        let inputIndex = 0;
        while (inputs.length > 0) {
            let list = false;
            let name = null;
            let visible = true;
            if (!innerMetadata || inputIndex === 0) {
                if (this.type && this.type.inputs && inputIndex < this.type.inputs.length) {
                    const input = this.type.inputs[inputIndex];
                    name = input.name;
                    if (class_name === 'BatchNormalization' && name === 'gamma' && config.scale === false) {
                        // No gamma tensor is present when scaling is disabled; skip its slot.
                        inputIndex++;
                        continue;
                    }
                    visible = input.visible !== false;
                    if (this.type.inputs[inputIndex].list) {
                        list = true;
                    }
                }
            } else {
                switch (class_name) {
                    case 'Bidirectional': {
                        let innerIndex = inputIndex;
                        if (innerMetadata && innerMetadata.inputs) {
                            if (innerIndex < innerMetadata.inputs.length) {
                                name = `forward_${innerMetadata.inputs[innerIndex].name}`;
                            } else {
                                innerIndex = innerIndex - innerMetadata.inputs.length + 1;
                                if (innerIndex < innerMetadata.inputs.length) {
                                    name = `backward_${innerMetadata.inputs[innerIndex].name}`;
                                }
                            }
                        }
                        visible = false;
                        break;
                    }
                    case 'TimeDistributed':
                        if (innerMetadata && innerMetadata.inputs && inputIndex < innerMetadata.inputs.length) {
                            name = innerMetadata.inputs[inputIndex].name;
                        }
                        break;
                    default:
                        break;
                }
            }
            // A `list` input swallows all remaining inputs; otherwise take exactly one.
            const input = list ? inputs.splice(0, inputs.length) : [inputs.shift()];
            const inputArguments = input.map((input) => {
                if (input.name) {
                    return values.map(input.name, null, input.initializer);
                }
                if (input.initializer) {
                    return values.map(input.name, null, input.initializer);
                }
                if (input.value !== undefined) {
                    const tensor = new keras.Tensor('', input.shape, config.dtype || '?', null, null, '|', input.value);
                    return values.map('', null, tensor);
                }
                throw new keras.Error(`Invalid argument '${JSON.stringify(input.name)}'.`);
            });
            if (!name && inputArguments.length === 1 && inputArguments[0].initializer && inputArguments[0].initializer.name) {
                if (names.length === 1 && names[0] === '') {
                    name = inputArguments[0].initializer.name;
                } else {
                    // Derive the input name from the last path segment of the tensor name,
                    // keeping known two-word suffixes such as 'recurrent_kernel' intact.
                    const parts = inputArguments[0].initializer.name.split('/').pop().split(':').shift().split('_');
                    const inputName1 = parts.pop();
                    const inputName2 = parts.length > 0 ? [parts.pop(), inputName1].join('_') : '';
                    const inputNames = new Set(['recurrent_kernel', 'running_mean', 'running_std', 'moving_mean', 'moving_variance', 'depthwise_filter', 'pointwise_filter']);
                    name = inputNames.has(inputName2) ? inputName2 : inputName1;
                }
            }
            const argument = new keras.Argument(name || inputIndex.toString(), inputArguments, null, visible);
            this.inputs.push(argument);
            inputIndex++;
        }
        for (let i = 0; i < outputs.length; i++) {
            const output = outputs[i];
            const name = this.type && this.type.outputs && i < this.type.outputs.length && this.type.outputs[i] && this.type.outputs[i].name ? this.type.outputs[i].name : i.toString();
            const argument = new keras.Argument(name, output === undefined || output.length === 0 ? [] : [values.map(output)]);
            this.outputs.push(argument);
        }
        const inputTypes = new Map((this.type.inputs || []).map((input) => [input.name, input.type]));
        for (const [name, arg] of Object.entries(args)) {
            if (name !== 'name') {
                if ((arg && arg.name) || (inputTypes.has(name) && inputTypes.get(name) === 'Tensor' && arg)) {
                    if (arg.name) {
                        const value = values.map(arg.name);
                        const argument = new keras.Argument(name, [value]);
                        this.inputs.push(argument);
                    } else {
                        const tensor = new keras.Tensor('', arg.shape, config.dtype || '?', null, null, '|', arg.value);
                        const value = values.map('', null, tensor);
                        const argument = new keras.Argument(name, [value]);
                        this.inputs.push(argument);
                    }
                } else {
                    const schema = metadata.attribute(class_name, name);
                    // Fix: collect into the local list; `this.attributes` is rebuilt from it
                    // below, so pushing there discarded every call-time argument.
                    attributes.push([schema, name, arg]);
                }
            }
        }
        this.attributes = attributes.map(([schema, name, value]) => {
            let type = null;
            let visible = true;
            if (value && typeof value === 'object' && value.class_name && value.config) {
                value = convertAttributeValue(value);
            }
            switch (name) {
                case 'trainable':
                    type = 'boolean';
                    visible = false;
                    break;
                case 'dtype':
                    visible = false;
                    break;
                default: {
                    if (schema) {
                        type = schema.type ? schema.type : type;
                        if (schema.visible === false) {
                            visible = false;
                        } else if (schema.default !== undefined) {
                            // Attributes equal to their schema default are hidden.
                            if (Array.isArray(value)) {
                                if (Array.isArray(schema.default)) {
                                    // Fix: compare lengths; the original compared a number with the array itself.
                                    visible = value.length !== schema.default.length || !value.every((item, index) => item === schema.default[index]);
                                } else {
                                    visible = !value.every((item) => item === schema.default);
                                }
                            } else {
                                visible = value !== schema.default;
                            }
                        }
                    }
                    break;
                }
            }
            return new keras.Argument(name, value, type, visible);
        });
        if (typeof this.type.name !== 'string' || !this.type.name.split) { // #416
            throw new keras.Error(`Unsupported node type '${JSON.stringify(this.type.name)}'.`);
        }
    }
};
/**
 * Tensor with a typed shape and either literal or buffer-backed data.
 */
keras.Tensor = class {

    /**
     * @param {string} name - Tensor name.
     * @param {Array} shape - Dimensions, wrapped in a keras.TensorShape.
     * @param {string} type - Element data type, wrapped in a keras.TensorType.
     * @param {*} stride - Stored as given.
     * @param {?Object} quantization - Optional `{ scale, min }`; ignored when both are 0.
     * @param {string} encoding - Stored as given; '|' makes `values` return the data unchanged.
     * @param {*} data - Raw data; see the `values` getter.
     * @param {string} [location] - Stored as given.
     */
    constructor(name, shape, type, stride, quantization, encoding, data, location) {
        const tensorShape = new keras.TensorShape(shape);
        this.name = name;
        this.type = new keras.TensorType(type, tensorShape);
        this.stride = stride;
        this.encoding = encoding;
        this._data = data;
        this.location = location;
        const isIdentity = quantization && quantization.scale === 0 && quantization.min === 0;
        if (quantization && !isIdentity) {
            const { scale, min } = quantization;
            this.quantization = { type: 'linear', scale: [scale], min: [min] };
        }
    }

    /**
     * The stored data as is for '|' encoding or when it is null; otherwise the
     * Uint8Array itself, or the result of calling `peek()` on the data object.
     */
    get values() {
        const data = this._data;
        if (this.encoding === '|' || data === null) {
            return data;
        }
        if (data instanceof Uint8Array) {
            return data;
        }
        return data.peek();
    }
};
/**
 * Pairs an element data type with a shape.
 */
keras.TensorType = class {

    /**
     * @param {string} dataType - Element data type name.
     * @param {Object} shape - Shape object; its `toString()` is used when formatting.
     */
    constructor(dataType, shape) {
        Object.assign(this, { dataType, shape });
    }

    /** Data type immediately followed by the shape's string form. */
    toString() {
        const suffix = this.shape.toString();
        return this.dataType + suffix;
    }
};
/**
 * Holds the dimensions of a tensor.
 */
keras.TensorShape = class {

    /**
     * @param {Array} dimensions - Dimension list, stored as given.
     */
    constructor(dimensions) {
        this.dimensions = dimensions;
    }

    /** Formats as '[d0,d1,...]'; yields '' when there are no dimensions. */
    toString() {
        const dims = this.dimensions;
        if (!dims || !(dims.length > 0)) {
            return '';
        }
        const text = dims.map((dimension) => dimension.toString()).join(',');
        return `[${text}]`;
    }
};
/**
 * Metadata view that layers locally registered types over a base metadata object.
 */
keras.GraphMetadata = class {

    /**
     * @param {Object} metadata - Base metadata providing `type(name)` and `attribute(type, name)`.
     */
    constructor(metadata) {
        this._metadata = metadata;
        this._types = new Map();
    }

    /** Returns the locally registered type for `name` if one exists, else asks the base metadata. */
    type(name) {
        const local = this._types;
        return local.has(name) ? local.get(name) : this._metadata.type(name);
    }

    /** Attribute lookup is delegated to the base metadata unchanged. */
    attribute(type, name) {
        const base = this._metadata;
        return base.attribute(type, name);
    }

    /** Registers (or replaces) a local type entry. */
    add(type, metadata) {
        this._types.set(type, metadata);
    }
};
/**
 * Tensors grouped by layer name, with prefix-based lookup.
 */
keras.Weights = class {

    constructor() {
        this._map = new Map();
    }

    /** True when no layer entry has been added. */
    get empty() {
        return this._map.size === 0;
    }

    /** Appends `tensor` to the list stored under `layer_name`, creating the list on first use. */
    add(layer_name, tensor) {
        const entries = this._map;
        if (!entries.has(layer_name)) {
            entries.set(layer_name, []);
        }
        entries.get(layer_name).push(tensor);
    }

    /**
     * Finds the tensors of layer `name`.
     *
     * With a `group`, only the list stored under the group's first path segment is
     * searched: first for tensor names starting with `name/`, then `group/name/`.
     * Without a `group`, the list stored under `name` wins; otherwise the list
     * stored under '' is filtered by the `name/` prefix.
     *
     * @returns {Array} Matching tensors, or an empty array.
     */
    get(group, name) {
        const withPrefix = (tensors, prefix) => tensors.filter((tensor) => tensor.name.startsWith(prefix));
        if (group) {
            const root = group.split('/').shift();
            const scoped = this._map.get(root);
            if (scoped) {
                const direct = withPrefix(scoped, `${name}/`);
                if (direct.length > 0) {
                    return direct;
                }
                const qualified = withPrefix(scoped, `${group}/${name}/`);
                if (qualified.length > 0) {
                    return qualified;
                }
            }
            return [];
        }
        const exact = this._map.get(name);
        if (exact && exact.length > 0) {
            return exact;
        }
        const unnamed = this._map.get('');
        if (unnamed && unnamed.length > 0) {
            // `group` is falsy on this path, so the prefix is just the layer name.
            const prefixed = withPrefix(unnamed, `${name}/`);
            if (prefixed.length > 0) {
                return prefixed;
            }
        }
        return [];
    }

    /** Iterator over the layer names added so far. */
    keys() {
        return this._map.keys();
    }
};
/**
 * Error type thrown for Keras model loading failures.
 */
keras.Error = class extends Error {

    /**
     * @param {string} message - Description of the failure.
     */
    constructor(message) {
        super(message);
        const label = 'Error loading Keras model.';
        this.name = label;
    }
};
/**
 * Reader for TensorFlow.js layers-model content.
 *
 * `open()` detects one of four entry kinds ('' for a model JSON, 'weights.json',
 * 'weights.bin', 'metadata'). `read()` then populates `format`, `config`,
 * `weights` and, for model JSON, `producer` and `backend`.
 */
tfjs.Container = class {

    /**
     * Detects whether `context` holds TensorFlow.js content.
     *
     * @param {Object} context - Provides `peek('json')` and `identifier`.
     * @returns {Promise<?tfjs.Container>} A container, or null when nothing matches.
     */
    static async open(context) {
        const json = await context.peek('json');
        if (json) {
            if (json.modelTopology && (json.format === 'layers-model' || json.modelTopology.class_name || json.modelTopology.model_config)) {
                return new tfjs.Container(context, '');
            }
            // Guard against null entries so unrelated JSON arrays do not throw during detection.
            if (Array.isArray(json) && json.every((item) => item && item.weights && item.paths)) {
                return new tfjs.Container(context, 'weights.json');
            }
            if (json.tfjsVersion) {
                return new tfjs.Container(context, 'metadata');
            }
        }
        const identifier = context.identifier;
        if (/^.*group\d+-shard\d+of\d+(\.bin)?$/.test(identifier)) {
            return new tfjs.Container(context, 'weights.bin');
        }
        return null;
    }

    /**
     * @param {Object} context - Source context.
     * @param {string} type - Entry kind chosen by `open()`.
     */
    constructor(context, type) {
        this.context = context;
        this.type = type;
    }

    /**
     * Loads the model description and weights for the detected entry kind.
     *
     * @returns {Promise<void>}
     * @throws {tfjs.Error} For an unknown entry kind or a JSON file that is not a layers-model.
     */
    async read() {
        const context = this.context;
        switch (this.type) {
            case '': {
                const obj = await context.peek('json');
                return this._openModelJson(obj);
            }
            case 'weights.json': {
                this.format = 'TensorFlow.js Weights';
                this.config = null;
                const obj = await context.peek('json');
                const manifests = Array.from(obj);
                for (const manifest of manifests) {
                    for (const weight of manifest.weights) {
                        // Group weights by everything before the last '/' of their name.
                        const name = weight.name;
                        const index = name.lastIndexOf('/');
                        weight.identifier = index === -1 ? name : name.substring(0, index);
                    }
                }
                return this._openManifests(manifests);
            }
            case 'weights.bin':
            case 'metadata': {
                // Opened through a side file: load the sibling model.json instead.
                const content = await this.context.fetch('model.json');
                const obj = await content.read('json');
                return this._openModelJson(obj);
            }
            default: {
                throw new tfjs.Error(`Unsupported TensorFlow.js format '${this.type}'.`);
            }
        }
    }

    /**
     * Builds `this.weights` from the manifests, slicing tensor data out of the
     * concatenated shard buffers when every shard of a manifest is available.
     *
     * @param {Array} manifests - Weight manifests (`paths`, `weights`).
     * @param {Map} shards - Shard path to byte buffer; may be empty.
     * @throws {keras.Error} For a weight data type with unknown item size.
     */
    _openShards(manifests, shards) {
        this.weights = new keras.Weights();
        const dtype_size_map = new Map([
            ['float16', 2], ['float32', 4], ['float64', 8],
            ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
            ['uint8', 1], ['uint16', 2], ['uint32', 4], ['uint64', 8]
        ]);
        for (const manifest of manifests) {
            let buffer = null;
            let location = '';
            if (Array.isArray(manifest.paths) && manifest.paths.length > 0 && manifest.paths.every((path) => shards.has(path))) {
                const list = manifest.paths.map((path) => shards.get(path));
                location = manifest.paths.join(', ');
                const size = list.reduce((a, b) => a + b.length, 0);
                buffer = new Uint8Array(size);
                let offset = 0;
                for (const item of list) {
                    buffer.set(item, offset);
                    offset += item.length;
                }
            }
            let offset = 0;
            const entries = Array.isArray(manifest.weights) ? manifest.weights : [];
            for (const weight of entries) {
                // Quantized weights are stored using the quantization data type.
                const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
                if (!dtype_size_map.has(dtype)) {
                    throw new keras.Error(`Unsupported weight data type size '${dtype}'.`);
                }
                const itemsize = dtype_size_map.get(dtype);
                const size = weight.shape.reduce((a, b) => a * b, 1);
                const length = itemsize * size;
                // Without shard data the tensor is still registered, with null data.
                const data = buffer ? buffer.slice(offset, offset + length) : null;
                const tensor = new keras.Tensor(weight.name, weight.shape, dtype, null, weight.quantization, '<', data, location);
                this.weights.add(weight.identifier, tensor);
                offset += length;
            }
        }
    }

    /**
     * Fetches every distinct shard referenced by the manifests and builds the weights.
     * When fetching or decoding fails, the weights are rebuilt without shard data.
     *
     * @param {Array} manifests - Weight manifests.
     * @returns {Promise<void>}
     */
    async _openManifests(manifests) {
        const shards = new Map();
        for (const manifest of manifests) {
            // A manifest without `paths` simply contributes no shards.
            const paths = Array.isArray(manifest.paths) ? manifest.paths : [];
            for (const path of paths) {
                if (!shards.has(path)) {
                    const promise = this.context.fetch(path);
                    shards.set(path, promise);
                }
            }
        }
        const keys = Array.from(shards.keys());
        try {
            const contexts = await Promise.all(shards.values());
            // Promise.all preserves order, so results line up with the map keys.
            for (let i = 0; i < keys.length; i++) {
                const buffer = contexts[i].stream.peek();
                shards.set(keys[i], buffer);
            }
            this._openShards(manifests, shards);
        } catch {
            // Deliberate best effort: fall back to weights without data.
            shards.clear();
            this._openShards(manifests, shards);
        }
    }

    /**
     * Reads a model.json object: sets `format`, `producer`, `backend` and `config`,
     * then loads the weights listed in `weightsManifest` (if any).
     *
     * @param {Object} obj - Parsed model.json content.
     * @returns {Promise<void>}
     * @throws {tfjs.Error} When `obj` is not a layers-model description.
     */
    _openModelJson(obj) {
        if (!obj || !obj.modelTopology || (obj.format !== 'layers-model' && !obj.modelTopology.model_config && !obj.modelTopology.config)) {
            throw new tfjs.Error('File format is not TensorFlow.js layers-model.');
        }
        const modelTopology = obj.modelTopology;
        if (obj.format) {
            this.format = `TensorFlow.js ${obj.format}`;
        } else if (modelTopology.keras_version) {
            const match = modelTopology.keras_version.match(/^(.+)\s+(\d.*)$/);
            const version = match ? `${match[1]} v${match[2]}` : `v${modelTopology.keras_version}`;
            this.format = `TensorFlow.js Keras ${version}`;
        } else {
            this.format = 'TensorFlow.js Keras';
        }
        this.producer = obj.convertedBy || obj.generatedBy || '';
        this.backend = modelTopology.backend || '';
        // A topology-only model.json has no manifest; treat it as having no weights
        // instead of failing on a non-iterable value.
        const manifests = Array.isArray(obj.weightsManifest) ? obj.weightsManifest : [];
        for (const manifest of manifests) {
            const entries = Array.isArray(manifest.weights) ? manifest.weights : [];
            for (const weight of entries) {
                weight.identifier = '';
            }
        }
        this.config = modelTopology.model_config ? modelTopology.model_config : modelTopology;
        return this._openManifests(manifests);
    }
};
/**
 * Error type thrown for TensorFlow.js model loading failures.
 */
tfjs.Error = class extends Error {

    /**
     * @param {string} message - Description of the failure.
     */
    constructor(message) {
        super(message);
        const label = 'Error loading TensorFlow.js model.';
        this.name = label;
    }
};
- export const ModelFactory = keras.ModelFactory;
|