|
|
@@ -496,20 +496,11 @@ tf.Model = class {
|
|
|
this._producer = producer || '';
|
|
|
this._graphs = [];
|
|
|
if (model) {
|
|
|
- const graphs = [];
|
|
|
for (let i = 0; i < model.meta_graphs.length; i++) {
|
|
|
const meta_graph = model.meta_graphs[i];
|
|
|
const name = (meta_graph.meta_info_def && meta_graph.meta_info_def.any_info) ? meta_graph.meta_info_def.any_info.toString() : ((model.meta_graphs.length > 1) ? i.toString() : '-');
|
|
|
const graph = new tf.Graph(metadata, meta_graph, name, bundle);
|
|
|
- graphs.push(graph);
|
|
|
- }
|
|
|
- // Recursively add all subgraphs.
|
|
|
- while (graphs.length > 0) {
|
|
|
- const graph = graphs.shift();
|
|
|
this._graphs.push(graph);
|
|
|
- for (const func of graph.functions || []) {
|
|
|
- graphs.push(func);
|
|
|
- }
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
@@ -543,10 +534,8 @@ tf.Graph = class {
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._nodes = [];
|
|
|
- this._functions = [];
|
|
|
|
|
|
if (meta_graph && meta_graph.graph_def) {
|
|
|
- metadata = new tf.GraphMetadata(metadata, meta_graph.meta_info_def);
|
|
|
const graph = meta_graph.graph_def;
|
|
|
if (graph.versions) {
|
|
|
this._version = 'v' + graph.versions.producer.toString();
|
|
|
@@ -561,6 +550,8 @@ tf.Graph = class {
|
|
|
this._tags = meta_graph.meta_info_def.tags.join(', ');
|
|
|
}
|
|
|
|
|
|
+ metadata = new tf.GraphMetadata(metadata, graph.library);
|
|
|
+
|
|
|
const nodes = graph.node;
|
|
|
if (nodes) {
|
|
|
const node_map = new Map();
|
|
|
@@ -911,15 +902,6 @@ tf.Graph = class {
|
|
|
this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if (graph.library) {
|
|
|
- const funcs = graph.library.function;
|
|
|
- for (const func of funcs) {
|
|
|
- const value = new tf.Function(this, func, metadata);
|
|
|
- metadata.add(value);
|
|
|
- this._functions.push(value);
|
|
|
- }
|
|
|
- }
|
|
|
}
|
|
|
else if (bundle) {
|
|
|
const nodeNames = [];
|
|
|
@@ -986,10 +968,6 @@ tf.Graph = class {
|
|
|
get metadata() {
|
|
|
return this._metadata;
|
|
|
}
|
|
|
-
|
|
|
- get functions() {
|
|
|
- return this._functions;
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
tf.Parameter = class {
|
|
|
@@ -1040,8 +1018,8 @@ tf.Argument = class {
|
|
|
};
|
|
|
|
|
|
tf.Function = class {
|
|
|
+ constructor(metadata, func) {
|
|
|
|
|
|
- constructor(graph, func, metadata) {
|
|
|
this._name = func.signature.name;
|
|
|
this._version = null;
|
|
|
this._tags = null;
|
|
|
@@ -1195,6 +1173,10 @@ tf.Function = class {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ get type() {
|
|
|
+ return 'function';
|
|
|
+ }
|
|
|
+
|
|
|
get name() {
|
|
|
return this._name;
|
|
|
}
|
|
|
@@ -1257,9 +1239,7 @@ tf.Node = class {
|
|
|
if (node.attr) {
|
|
|
this._attributes = Object.keys(node.attr).map((name) => {
|
|
|
const value = node.attr[name];
|
|
|
- const schema = value && value.metadata ? value.metadata : metadata.attribute(op, name);
|
|
|
- const visible = metadata.visible(this._type, name);
|
|
|
- return new tf.Attribute(schema, name, value, visible);
|
|
|
+ return new tf.Attribute(metadata, op, name, value);
|
|
|
});
|
|
|
}
|
|
|
let inputIndex = 0;
|
|
|
@@ -1370,10 +1350,12 @@ tf.Node = class {
|
|
|
|
|
|
tf.Attribute = class {
|
|
|
|
|
|
- constructor(schema, name, value, visible) {
|
|
|
+ constructor(metadata, op, name, value) {
|
|
|
this._name = name;
|
|
|
this._value = null;
|
|
|
this._type = null;
|
|
|
+ const schema = value && value.metadata ? value.metadata : metadata.attribute(op, name);
|
|
|
+ const visible = metadata.visible(op, name);
|
|
|
if (Object.prototype.hasOwnProperty.call(value, 'tensor')) {
|
|
|
this._type = 'tensor';
|
|
|
this._value = new tf.Tensor(value.tensor);
|
|
|
@@ -1403,9 +1385,8 @@ tf.Attribute = class {
|
|
|
this._value = tf.Utility.decodeText(value.s);
|
|
|
break;
|
|
|
case 'func': {
|
|
|
- const func = value.func;
|
|
|
this._type = 'function';
|
|
|
- this._value = func.name;
|
|
|
+ this._value = metadata.type(value.func.name);
|
|
|
break;
|
|
|
}
|
|
|
case 'list': {
|
|
|
@@ -1427,6 +1408,10 @@ tf.Attribute = class {
|
|
|
this._type = 'shape[]';
|
|
|
this._value = list.shape.map((shape) => new tf.TensorShape(shape));
|
|
|
}
|
|
|
+ else if (list.func && list.func.length > 0) {
|
|
|
+ this._type = 'function[]';
|
|
|
+ this._value = list.func.map((func) => metadata.type(func.name));
|
|
|
+ }
|
|
|
else {
|
|
|
this._value = [];
|
|
|
}
|
|
|
@@ -2243,22 +2228,31 @@ tf.EventFileReader = class {
|
|
|
|
|
|
tf.GraphMetadata = class {
|
|
|
|
|
|
- constructor(metadata) {
|
|
|
+ constructor(metadata, library) {
|
|
|
this._metadata = metadata;
|
|
|
this._functions = new Map();
|
|
|
this._attributes = new Map();
|
|
|
this._visibleCache = new Map();
|
|
|
- }
|
|
|
|
|
|
- add(func) {
|
|
|
- if (this._functions.has(func.name)) {
|
|
|
- throw new tf.Error("Duplicate function name '" + func.name + "'.");
|
|
|
+ if (library && Array.isArray(library.function)) {
|
|
|
+ for (const func of library.function) {
|
|
|
+ const name = func.signature.name;
|
|
|
+ if (this._functions.has(func.name)) {
|
|
|
+ throw new tf.Error("Duplicate function name '" + func.name + "'.");
|
|
|
+ }
|
|
|
+ this._functions.set(name, func);
|
|
|
+ }
|
|
|
}
|
|
|
- this._functions.set(func.name, func);
|
|
|
+
|
|
|
}
|
|
|
|
|
|
type(name) {
|
|
|
if (this._functions.has(name)) {
|
|
|
+ const func = this._functions.get(name);
|
|
|
+ if (func instanceof tf.Function) {
|
|
|
+ return func;
|
|
|
+ }
|
|
|
+ this._functions.set(name, new tf.Function(this, func));
|
|
|
return this._functions.get(name);
|
|
|
}
|
|
|
return this._metadata.type(name);
|