|
|
@@ -704,7 +704,6 @@ tf.Graph = class {
|
|
|
this._outputs = [];
|
|
|
this._nodes = [];
|
|
|
this._version = null;
|
|
|
-
|
|
|
if (meta_graph && meta_graph.graph_def) {
|
|
|
const graph = meta_graph.graph_def;
|
|
|
if (graph.versions) {
|
|
|
@@ -727,8 +726,7 @@ tf.Graph = class {
|
|
|
this._outputs = context.outputs;
|
|
|
}
|
|
|
else if (bundle) {
|
|
|
- const nodeNames = [];
|
|
|
- const nodeMap = new Map();
|
|
|
+ const nodes = new Map();
|
|
|
for (const tensor of bundle.tensors) {
|
|
|
const parts = tensor.name.split('/');
|
|
|
if (bundle.format === 2) {
|
|
|
@@ -745,17 +743,17 @@ tf.Graph = class {
|
|
|
}
|
|
|
}
|
|
|
const tensorName = parts.pop();
|
|
|
- const nodeName = parts.join('/');
|
|
|
- if (!nodeMap.has(nodeName)) {
|
|
|
- nodeNames.push(nodeName);
|
|
|
- nodeMap.set(nodeName, []);
|
|
|
+ const name = parts.join('/');
|
|
|
+ if (!nodes.has(name)) {
|
|
|
+ nodes.set(name, []);
|
|
|
}
|
|
|
- nodeMap.get(nodeName).push({ name: tensorName, value: tensor });
|
|
|
+ nodes.get(name).push({ name: tensorName, value: tensor });
|
|
|
}
|
|
|
const namespaces = new Set();
|
|
|
- for (const name of nodeNames) {
|
|
|
- this._nodes.push(new tf.Node(metadata, namespaces, null, 'Node', name, null, nodeMap.get(name)));
|
|
|
- }
|
|
|
+ this._nodes = Array.from(nodes).map((entry) => {
|
|
|
+ const node = { op: 'Node', name: entry[0] };
|
|
|
+ return new tf.Node(metadata, node, namespaces, null, entry[1]);
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -924,41 +922,45 @@ tf.Function = class {
|
|
|
|
|
|
tf.Node = class {
|
|
|
|
|
|
- constructor(metadata, namespaces, node, op, name, initializers, tensors) {
|
|
|
- this._type = Object.assign({}, node && node.metadata ? node.metadata : metadata.type(op) || { name: op });
|
|
|
- this._type.identifier = this._type.name;
|
|
|
- this._type.name = op;
|
|
|
- this._name = name;
|
|
|
+ constructor(metadata, node, namespaces, initializers, tensors) {
|
|
|
+ this._type = node.metadata || metadata.type(node.op) || { name: node.op };
|
|
|
+ this._name = node.name;
|
|
|
this._attributes = [];
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
-
|
|
|
this._group = '';
|
|
|
- if (namespaces.has(name)) {
|
|
|
- this._group = name;
|
|
|
- }
|
|
|
- else {
|
|
|
- const lastIndex = name.lastIndexOf('/');
|
|
|
- if (lastIndex != -1) {
|
|
|
- const namespace = name.substring(0, lastIndex);
|
|
|
- if (namespaces.has(namespace)) {
|
|
|
- this._group = namespace;
|
|
|
+ if (node.name) {
|
|
|
+ if (namespaces.has(node.name)) {
|
|
|
+ this._group = node.name;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const lastIndex = node.name.lastIndexOf('/');
|
|
|
+ if (lastIndex != -1) {
|
|
|
+ const namespace = node.name.substring(0, lastIndex);
|
|
|
+ if (namespaces.has(namespace)) {
|
|
|
+ this._group = namespace;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if (node) {
|
|
|
+ if (tensors) {
|
|
|
+ for (const tensor of tensors) {
|
|
|
+ this._inputs.push(new tf.Parameter(tensor.name, [
|
|
|
+ new tf.Argument(tensor.value.name, null, tensor.value)
|
|
|
+ ]));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
if (node.device !== undefined) {
|
|
|
this._device = node.device;
|
|
|
}
|
|
|
if (node.attr) {
|
|
|
- this._attributes = Object.keys(node.attr).map((name) => {
|
|
|
- const value = node.attr[name];
|
|
|
- return new tf.Attribute(metadata, op, name, value);
|
|
|
+ this._attributes = Object.entries(node.attr).map((entry) => {
|
|
|
+ return new tf.Attribute(metadata, node.op, entry[0], entry[1]);
|
|
|
});
|
|
|
}
|
|
|
let inputIndex = 0;
|
|
|
- const inputs = node.input.filter((input) => !input.name.startsWith('^'));
|
|
|
+ const inputs = (node.input || []).filter((input) => !input.name.startsWith('^'));
|
|
|
if (this._type && this._type.inputs) {
|
|
|
for (const input of this._type.inputs) {
|
|
|
let inputCount = 1;
|
|
|
@@ -987,7 +989,7 @@ tf.Node = class {
|
|
|
]);
|
|
|
}));
|
|
|
let outputIndex = 0;
|
|
|
- const outputs = node.output;
|
|
|
+ const outputs = node.output || [];
|
|
|
if (this._type && this._type.outputs) {
|
|
|
for (const output of this._type.outputs) {
|
|
|
let outputCount = 1;
|
|
|
@@ -1015,14 +1017,8 @@ tf.Node = class {
|
|
|
new tf.Argument(output.name ? output.name : '-', null, null)
|
|
|
]);
|
|
|
}));
|
|
|
- this._controlDependencies = node.controlDependencies.map((input) => new tf.Argument(input.name));
|
|
|
- }
|
|
|
- else if (tensors) {
|
|
|
- for (const tensor of tensors) {
|
|
|
- this._inputs.push(new tf.Parameter(tensor.name, [
|
|
|
- new tf.Argument(tensor.value.name, null, tensor.value)
|
|
|
- ]));
|
|
|
- }
|
|
|
+ const controlDependencies = node.controlDependencies || [];
|
|
|
+ this._controlDependencies = controlDependencies.map((input) => new tf.Argument(input.name));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -1101,9 +1097,8 @@ tf.Attribute = class {
|
|
|
break;
|
|
|
}
|
|
|
case 'func': {
|
|
|
- const name = value.func.name;
|
|
|
this._type = 'function';
|
|
|
- this._value = metadata.type(name);
|
|
|
+ this._value = new tf.Node(metadata, { op: value.func.name, attr: value.func.attr });
|
|
|
break;
|
|
|
}
|
|
|
case 'list': {
|
|
|
@@ -1127,7 +1122,7 @@ tf.Attribute = class {
|
|
|
}
|
|
|
else if (list.func && list.func.length > 0) {
|
|
|
this._type = 'function[]';
|
|
|
- this._value = list.func.map((func) => metadata.type(func.name));
|
|
|
+ this._value = list.func.map((func) => new tf.Node(metadata, { op: func.name, attr: func.attr }));
|
|
|
}
|
|
|
else {
|
|
|
this._value = [];
|
|
|
@@ -2426,7 +2421,8 @@ tf.Utility = class {
|
|
|
}
|
|
|
}
|
|
|
if (match) {
|
|
|
- node.metadata = metadata;
|
|
|
+ node.metadata = Object.assign({}, metadata);
|
|
|
+ node.metadata.name = node.op;
|
|
|
break;
|
|
|
}
|
|
|
else {
|
|
|
@@ -2473,7 +2469,7 @@ tf.Utility = class {
|
|
|
context.inputs.push(input);
|
|
|
}
|
|
|
for (const node of node_map.values()) {
|
|
|
- context.nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
|
|
|
+ context.nodes.push(new tf.Node(metadata, node, namespaces, initializers));
|
|
|
}
|
|
|
return context;
|
|
|
}
|