|
|
@@ -132,7 +132,7 @@ keras.Model = class {
|
|
|
this._producer = producer;
|
|
|
this._graphs = [];
|
|
|
|
|
|
- let weights = {};
|
|
|
+ let weights = new keras.Weights();
|
|
|
if (rootGroup) {
|
|
|
let model_weights_group = rootGroup.group('model_weights');
|
|
|
if (!model_weights_group && rootGroup.attribute('layer_names')) {
|
|
|
@@ -140,51 +140,24 @@ keras.Model = class {
|
|
|
}
|
|
|
if (model_weights_group) {
|
|
|
model_weights_group = new keras.Group(model_weights_group);
|
|
|
- let layer_names = model_weights_group.attribute('layer_names');
|
|
|
- let layer_names_map = new Set();
|
|
|
- for (const layer_name of layer_names) {
|
|
|
- layer_names_map.add(layer_name);
|
|
|
- }
|
|
|
- for (const layer_name of layer_names) {
|
|
|
- let layer_weights = model_weights_group.group(layer_name);
|
|
|
+ for (const layer_name of model_weights_group.attribute('layer_names')) {
|
|
|
+ const layer_weights = model_weights_group.group(layer_name);
|
|
|
if (layer_weights) {
|
|
|
- let weight_names = layer_weights.attribute('weight_names');
|
|
|
- if (layer_weights && weight_names && weight_names.length > 0) {
|
|
|
+ const weight_names = layer_weights.attribute('weight_names');
|
|
|
+ if (weight_names && weight_names.length > 0) {
|
|
|
for (let weight_name of weight_names) {
|
|
|
- let group = layer_weights.group(weight_name);
|
|
|
- if (group) {
|
|
|
- let variable = group.value;
|
|
|
- if (variable) {
|
|
|
- if (model_config) {
|
|
|
- let initializer = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
|
|
|
- let parts = weight_name.split('/');
|
|
|
- parts.pop();
|
|
|
- let match = false;
|
|
|
- while (parts.length > 0) {
|
|
|
- let name = parts.join('/');
|
|
|
- if (layer_names_map.has(name)) {
|
|
|
- match = true;
|
|
|
- }
|
|
|
- weights[name] = weights[name] || [];
|
|
|
- weights[name].push(initializer);
|
|
|
- parts.shift();
|
|
|
- }
|
|
|
- if (!match) {
|
|
|
- weights[layer_name] = weights[layer_name] || [];
|
|
|
- weights[layer_name].push(initializer);
|
|
|
- }
|
|
|
- }
|
|
|
- else {
|
|
|
- if (!weight_name.startsWith(layer_name + '/')) {
|
|
|
- weight_name = layer_name + '/' + weight_name;
|
|
|
- }
|
|
|
- let initializer = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
|
|
|
- let parts = weight_name.split('/');
|
|
|
- parts.pop();
|
|
|
- let name = parts.join('/');
|
|
|
- weights[name] = weights[name] || [];
|
|
|
- weights[name].push(initializer);
|
|
|
- }
|
|
|
+ const weight = layer_weights.group(weight_name);
|
|
|
+ if (weight && weight.value) {
|
|
|
+ const variable = weight.value;
|
|
|
+ const tensor = new keras.Tensor(weight_name, variable.type, variable.shape, variable.littleEndian, variable.data, '');
|
|
|
+ if (model_config) {
|
|
|
+ weights.add(layer_name, tensor);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const components = weight_name.split('/');
|
|
|
+ components.pop();
|
|
|
+ const name = (components.length == 0 || components[0] !== layer_name) ? [ layer_name ].concat(components).join('/') : components.join('/');
|
|
|
+ weights.add(name, tensor);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -196,15 +169,8 @@ keras.Model = class {
|
|
|
else if (weightsManifest) {
|
|
|
for (const manifest of weightsManifest) {
|
|
|
for (const weight of manifest.weights) {
|
|
|
- let p = weight.name.split('/');
|
|
|
- p.pop();
|
|
|
- let initializer = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, manifest.paths.join(';'));
|
|
|
- while (p.length > 0) {
|
|
|
- let weightName = p.join('/');
|
|
|
- weights[weightName] = weights[weightName] || [];
|
|
|
- weights[weightName].push(initializer);
|
|
|
- p.shift();
|
|
|
- }
|
|
|
+ const tensor = new keras.Tensor(weight.name, weight.dtype, weight.shape, false, null, manifest.paths.join(';'));
|
|
|
+ weights.add('', tensor);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -262,9 +228,9 @@ keras.Graph = class {
|
|
|
}
|
|
|
}
|
|
|
else if (weights) {
|
|
|
- for (const layer of Object.keys(weights)) {
|
|
|
- if (weights[layer].length <= 6) {
|
|
|
- const node = new keras.Node(metadata, 'Weights', { name: layer }, [], [], false, weights);
|
|
|
+ for (const layer of weights.keys()) {
|
|
|
+ if (weights.get('', layer).length <= 6) {
|
|
|
+ const node = new keras.Node(metadata, 'Weights', { name: layer }, [], [], '', weights);
|
|
|
this._nodes.push(node)
|
|
|
}
|
|
|
}
|
|
|
@@ -401,8 +367,7 @@ keras.Graph = class {
|
|
|
let inputType = null;
|
|
|
let argument = inputName;
|
|
|
let index = 0;
|
|
|
- let layers = config.layers ? config.layers : config;
|
|
|
-
|
|
|
+ const layers = config.layers ? config.layers : config;
|
|
|
for (const layer of layers) {
|
|
|
let name = index.toString();
|
|
|
let nodeInputs = [ argument ];
|
|
|
@@ -438,14 +403,16 @@ keras.Graph = class {
|
|
|
}
|
|
|
|
|
|
_loadNode(layer, inputs, outputs, weights, group, inputMap) {
|
|
|
- let class_name = layer.class_name;
|
|
|
+ const class_name = layer.class_name;
|
|
|
switch (class_name) {
|
|
|
case 'Sequential': {
|
|
|
- this._loadSequential(layer.config, weights, layer.name, inputs, outputs);
|
|
|
+ const name = layer.name || (layer.config ? layer.config.name : '')
|
|
|
+ this._loadSequential(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
|
|
|
break;
|
|
|
}
|
|
|
case 'Model': {
|
|
|
- this._loadModel(layer.config, weights, layer.name, inputs, outputs);
|
|
|
+ const name = layer.name || (layer.config ? layer.config.name : '')
|
|
|
+ this._loadModel(layer.config, weights, (group ? group + '/' : '') + name, inputs, outputs);
|
|
|
break;
|
|
|
}
|
|
|
default: {
|
|
|
@@ -524,43 +491,43 @@ keras.Argument = class {
|
|
|
keras.Node = class {
|
|
|
|
|
|
constructor(metadata, operator, config, inputs, outputs, group, weights) {
|
|
|
- if (group) {
|
|
|
- this._group = group;
|
|
|
- }
|
|
|
+ this._group = group || '';
|
|
|
this._metadata = metadata;
|
|
|
this._operator = operator;
|
|
|
- this._name = (config && config.name) ? config.name : '';
|
|
|
+ const name = config && config.name ? config.name : '';
|
|
|
+ this._name = (this._group ? this._group + '/' : '') + name;
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._attributes = [];
|
|
|
|
|
|
- let names = [ this._name ];
|
|
|
+ let names = [ name ];
|
|
|
if ((operator == 'Bidirectional' || operator == 'TimeDistributed') && (config && config.layer)) {
|
|
|
let inner = config.layer;
|
|
|
delete config.layer;
|
|
|
this._inner = new keras.Node(this._metadata, inner.class_name, inner.config, [], [], null, null);
|
|
|
if (operator == 'Bidirectional' && inner.config.name) {
|
|
|
- names = [ this._name + '/forward_' + inner.config.name, this._name + '/backward_' + inner.config.name ];
|
|
|
+ names = [ name + '/forward_' + inner.config.name, name + '/backward_' + inner.config.name ];
|
|
|
+ if (!group) {
|
|
|
+ group = name;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
let initializers = {};
|
|
|
if (weights) {
|
|
|
for (const name of names) {
|
|
|
- if (weights[name]) {
|
|
|
- for (const initializer of weights[name]) {
|
|
|
- inputs.push(initializer.name);
|
|
|
- initializers[initializer.name] = initializer;
|
|
|
- }
|
|
|
+ for (const initializer of weights.get(group, name)) {
|
|
|
+ inputs.push(initializer.name);
|
|
|
+ initializers[initializer.name] = initializer;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if (config) {
|
|
|
- for (const attributeName of Object.keys(config)) {
|
|
|
- const attributeValue = config[attributeName];
|
|
|
- if (attributeName != 'name' && attributeValue != null) {
|
|
|
- this._attributes.push(new keras.Attribute(this._metadata, this.operator, attributeName, attributeValue));
|
|
|
+ for (const name of Object.keys(config)) {
|
|
|
+ const value = config[name];
|
|
|
+ if (name != 'name' && value != null) {
|
|
|
+ this._attributes.push(new keras.Attribute(this._metadata, this.operator, name, value));
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -612,7 +579,7 @@ keras.Node = class {
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
- const input = !variadic ? [ inputs.shift() ] : inputs.slice(0, inputs.length);
|
|
|
+ const input = !variadic ? [ inputs.shift() ] : inputs.splice(0, inputs.length);
|
|
|
const inputArguments = input.map((id) => {
|
|
|
return new keras.Argument(id, null, initializers[id]);
|
|
|
});
|
|
|
@@ -640,16 +607,16 @@ keras.Node = class {
|
|
|
return this._operator;
|
|
|
}
|
|
|
|
|
|
+ get metadata() {
|
|
|
+ return this._metadata.type(this._operator);
|
|
|
+ }
|
|
|
+
|
|
|
get name() {
|
|
|
return this._name;
|
|
|
}
|
|
|
|
|
|
get group() {
|
|
|
- return this._group ? this._group : '';
|
|
|
- }
|
|
|
-
|
|
|
- get metadata() {
|
|
|
- return this._metadata.type(this._operator);
|
|
|
+ return this._group;
|
|
|
}
|
|
|
|
|
|
get inputs() {
|
|
|
@@ -1278,6 +1245,54 @@ keras.JsonParser = class {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+keras.Weights = class {
|
|
|
+
|
|
|
+ constructor() {
|
|
|
+ this._map = new Map();
|
|
|
+ }
|
|
|
+
|
|
|
+ add(layer_name, tensor) {
|
|
|
+ if (!this._map.has(layer_name)) {
|
|
|
+ this._map.set(layer_name, []);
|
|
|
+ }
|
|
|
+ this._map.get(layer_name).push(tensor);
|
|
|
+ }
|
|
|
+
|
|
|
+ get(group, name) {
|
|
|
+ if (group) {
|
|
|
+ const list = this._map.get(group.split('/').shift());
|
|
|
+ if (list) {
|
|
|
+ const match1 = list.filter((tensor) => tensor.name.startsWith(name + '/'));
|
|
|
+ if (match1.length > 0) {
|
|
|
+ return match1;
|
|
|
+ }
|
|
|
+ const match2 = list.filter((tensor) => tensor.name.startsWith(group + '/' + name + '/'));
|
|
|
+ if (match2.length > 0) {
|
|
|
+ return match2;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const match1 = this._map.get(name);
|
|
|
+ if (match1 && match1.length > 0) {
|
|
|
+ return match1;
|
|
|
+ }
|
|
|
+ const match2 = this._map.get('');
|
|
|
+ if (match2 && match2.length > 0) {
|
|
|
+ const match3 = match2.filter((tensor) => tensor.name.startsWith((group ? group + '/' : '') + name + '/'));
|
|
|
+ if (match3.length > 0) {
|
|
|
+ return match3;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return [];
|
|
|
+ }
|
|
|
+
|
|
|
+ keys() {
|
|
|
+ return this._map.keys();
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
keras.Error = class extends Error {
|
|
|
|
|
|
constructor(message) {
|