|
|
@@ -103,24 +103,29 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
if (saved_model.meta_graphs.every((meta_graph) => meta_graph.graph_def.node.every((node) => node.op.startsWith('aten::') || node.op.startsWith('prim::') || node.op === 'IO Node'))) {
|
|
|
producer = 'PyTorch';
|
|
|
- return context.request('pytorch-metadata.json', 'utf-8', null).then((data) => {
|
|
|
- const metadata = new Map();
|
|
|
- for (const item of JSON.parse(data)) {
|
|
|
- const index = item.name.indexOf(':');
|
|
|
- const key = (index !== -1) ? item.name.substring(0, index) : item.name;
|
|
|
- const name = key.replace(/^torch\./, 'aten::');
|
|
|
- if (!metadata.has(name)) {
|
|
|
- metadata.set(name, []);
|
|
|
+ const openPyTorchMetadata = (context, saved_model) => {
|
|
|
+ return context.request('pytorch-metadata.json', 'utf-8', null).then((data) => {
|
|
|
+ const metadata = new Map();
|
|
|
+ for (const item of JSON.parse(data)) {
|
|
|
+ const index = item.name.indexOf(':');
|
|
|
+ const key = (index !== -1) ? item.name.substring(0, index) : item.name;
|
|
|
+ const name = key.replace(/^torch\./, 'aten::');
|
|
|
+ if (!metadata.has(name)) {
|
|
|
+ metadata.set(name, []);
|
|
|
+ }
|
|
|
+ metadata.get(name).push(item);
|
|
|
}
|
|
|
- metadata.get(name).push(item);
|
|
|
- }
|
|
|
- for (const meta_graph of saved_model.meta_graphs) {
|
|
|
- for (const node of meta_graph.graph_def.node) {
|
|
|
- node.__metadata__ = Array.from(metadata.get(node.op) || []);
|
|
|
+ for (const meta_graph of saved_model.meta_graphs) {
|
|
|
+ for (const node of meta_graph.graph_def.node) {
|
|
|
+ node.__metadata__ = Array.from(metadata.get(node.op) || []);
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
- return openModel(saved_model, format, producer, null);
|
|
|
- }).catch(() => {
|
|
|
+ return saved_model;
|
|
|
+ }).catch(() => {
|
|
|
+ return saved_model;
|
|
|
+ });
|
|
|
+ };
|
|
|
+ return openPyTorchMetadata(context, saved_model).then((saved_model) => {
|
|
|
return openModel(saved_model, format, producer, null);
|
|
|
});
|
|
|
}
|
|
|
@@ -504,9 +509,9 @@ tf.Model = class {
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
- this._graphs.push(new tf.Graph(metadata, null, '', bundle));
|
|
|
+ const graph = new tf.Graph(metadata, null, '', bundle);
|
|
|
+ this._graphs.push(graph);
|
|
|
}
|
|
|
-
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -529,11 +534,11 @@ tf.Model = class {
|
|
|
tf.Graph = class {
|
|
|
|
|
|
constructor(metadata, meta_graph, name, bundle) {
|
|
|
- this._version = null;
|
|
|
this._name = name;
|
|
|
this._inputs = [];
|
|
|
this._outputs = [];
|
|
|
this._nodes = [];
|
|
|
+ this._version = null;
|
|
|
|
|
|
if (meta_graph && meta_graph.graph_def) {
|
|
|
const graph = meta_graph.graph_def;
|
|
|
@@ -549,359 +554,12 @@ tf.Graph = class {
|
|
|
if (meta_graph.meta_info_def && meta_graph.meta_info_def.tags) {
|
|
|
this._tags = meta_graph.meta_info_def.tags.join(', ');
|
|
|
}
|
|
|
-
|
|
|
metadata = new tf.GraphMetadata(metadata, graph.library);
|
|
|
-
|
|
|
- const nodes = graph.node;
|
|
|
- if (nodes) {
|
|
|
- const node_map = new Map();
|
|
|
- const namespaces = new Set();
|
|
|
- for (const node of nodes) {
|
|
|
- const nodeName = node.name;
|
|
|
- node_map.set(nodeName, node);
|
|
|
- if (node.op != 'Const') {
|
|
|
- const index = nodeName.lastIndexOf('/');
|
|
|
- if (index != -1) {
|
|
|
- const namespace = nodeName.substring(0, index);
|
|
|
- namespaces.add(namespace);
|
|
|
- }
|
|
|
- }
|
|
|
- node.output = [];
|
|
|
- }
|
|
|
- for (const node of nodes) {
|
|
|
- const inputs = node.input;
|
|
|
- node.input = [];
|
|
|
- node.controlDependencies = [];
|
|
|
- for (const input of inputs) {
|
|
|
- const split = input.split(':', 2);
|
|
|
- const input_name = split[0];
|
|
|
- const input_index = split.length == 1 ? 0 : parseInt(split[1]);
|
|
|
- const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
|
|
|
- const from = node_map.get(from_name);
|
|
|
- const output_name = input_index == 0 ? from_name : from_name + ':' + input_index.toString();
|
|
|
- const input_arg = from ? { name: output_name, from: from } : { name: output_name };
|
|
|
- if (input_name.startsWith('^')) {
|
|
|
- node.controlDependencies.push(input_arg);
|
|
|
- }
|
|
|
- else {
|
|
|
- node.input.push(input_arg);
|
|
|
- }
|
|
|
- if (from) {
|
|
|
- for (let i = from.output.length; i <= input_index; i++) {
|
|
|
- from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
|
|
|
- }
|
|
|
- from.output[input_index].to.push(node);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- const initializers = new Map();
|
|
|
- const map_tensor = (name, node, kind) => {
|
|
|
- if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const value = node.attr.value;
|
|
|
- if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
|
|
|
- const tensor = new tf.Tensor(value.tensor, name, kind);
|
|
|
- return new tf.Argument(name, tensor.type, tensor);
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- };
|
|
|
- const map_resource = (name, node, tensor) => {
|
|
|
- if (node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const dtype = node.attr.dtype.type;
|
|
|
- if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
|
|
|
- return new tf.Argument(name, null, tensor);
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- };
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(node.input[0].name);
|
|
|
- }
|
|
|
- const identity = node.input[0].from;
|
|
|
- if (identity && identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const initializer = map_tensor(node.name, identity.input[0].from, 'Identity Constant');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(identity.name);
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- const initializer = map_tensor(node.name, node, 'Const');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(node.name);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- }
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- if (node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
|
|
|
- const tensor = new tf.proto.tensorflow.TensorProto();
|
|
|
- tensor.dtype = node.attr.dtype.type;
|
|
|
- tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
|
|
|
- const initializer = map_resource(node.name, node.input[0].from, new tf.Tensor(tensor, name, 'Resource Variable'));
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(node.input[0].name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- const input_map = new Map();
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op == 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const dtype = node.attr.dtype;
|
|
|
- const shape = node.attr.shape;
|
|
|
- if (dtype && dtype.type && shape && shape.shape) {
|
|
|
- const name = node.name;
|
|
|
- const type = new tf.TensorType(dtype.type, shape.shape);
|
|
|
- const argument = new tf.Argument(name, type, null);
|
|
|
- input_map.set(name, new tf.Parameter(name, [ argument ]));
|
|
|
- node_map.delete(name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- const updatePyTorch = (node_map) => {
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'prim::Constant' && node.input.length === 0 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
|
|
|
- const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
- const match = /{\s*value\s*:\s*(.*)\s*}/.exec(value);
|
|
|
- if (match) {
|
|
|
- node.value = match[1].trim();
|
|
|
- }
|
|
|
- const empty = /{\s*}/.exec(value);
|
|
|
- if (empty) {
|
|
|
- node.value = null;
|
|
|
- }
|
|
|
- }
|
|
|
- if (node.op === 'prim::GetAttr' && node.input.length === 1 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
|
|
|
- const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
- const match = /{\s*name\s*:\s*([A-za-z0-9_]*)\s*}/.exec(value);
|
|
|
- if (match) {
|
|
|
- node.value = match[1].trim();
|
|
|
- }
|
|
|
- }
|
|
|
- if (node.op === 'IO Node' && node.controlDependencies.length === 0) {
|
|
|
- const shape = node.attr && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape ? node.attr._output_shapes.list.shape[0] : null;
|
|
|
- const type = shape ? new tf.TensorType('?', shape) : null;
|
|
|
- if (node.input.length === 0 && node.output.length === 1) {
|
|
|
- this._inputs.push(new tf.Parameter(node.name, [
|
|
|
- new tf.Argument(node.output[0].name, type, null)
|
|
|
- ]));
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- if (node.input.length === 1 && node.output.length === 0) {
|
|
|
- this._outputs.push(new tf.Parameter(node.name, [
|
|
|
- new tf.Argument(node.input[0].name, type, null)
|
|
|
- ]));
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- }
|
|
|
- if (Object.keys(node.attr).length === 2 &&
|
|
|
- node.attr.attr && node.attr.attr.s && node.attr._output_shapes) {
|
|
|
- const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
- if (/\s*/.exec(value) || /{\s*}/.exec(value)) {
|
|
|
- node.attr = {};
|
|
|
- delete node._output_shapes;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- const remove_input = (input, node) => {
|
|
|
- const from = input.from;
|
|
|
- if (from) {
|
|
|
- for (const output of from.output) {
|
|
|
- output.to = output.to.filter((to) => to !== node);
|
|
|
- }
|
|
|
- if (from.output.every((output) => output.to.length === 0) && from.controlDependencies.length === 0) {
|
|
|
- from.remove = true;
|
|
|
- }
|
|
|
- delete input.from;
|
|
|
- }
|
|
|
- };
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'prim::ListConstruct' && node.input.every((input) => input.from.value !== undefined) && node.controlDependencies.length === 0) {
|
|
|
- node.value = node.input.map((input) => input.from.value);
|
|
|
- for (const input of node.input) {
|
|
|
- remove_input(input, node);
|
|
|
- }
|
|
|
- node.input = [];
|
|
|
- }
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- const remove = new Set();
|
|
|
- for (let i = 0; i < node.input.length; i++) {
|
|
|
- const input = node.input[i];
|
|
|
- const from = input.from;
|
|
|
- if (from) {
|
|
|
- if (from.op === 'prim::GetAttr' && from.input.length === 1 && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
- remove_input(input, node);
|
|
|
- input.label = from.value;
|
|
|
- const tensor = new tf.Tensor(null, input.name, from.op);
|
|
|
- const argument = new tf.Argument(input.name, null, tensor);
|
|
|
- initializers.set(input.name, argument);
|
|
|
- }
|
|
|
- if (from.op === 'prim::Constant' && from.input.length === 0 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
- input.constant = from.value;
|
|
|
- remove_input(input, node);
|
|
|
- remove.add(input.name);
|
|
|
- }
|
|
|
- if (from.op === 'prim::ListConstruct' && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
- input.list = from.value;
|
|
|
- remove_input(input, node);
|
|
|
- remove.add(input.name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- if (node.__metadata__) {
|
|
|
- for (const metadata of node.__metadata__) {
|
|
|
- const parameters = Array.prototype.slice.call(metadata.inputs || []).concat(Array.prototype.slice.call(metadata.attributes || []));
|
|
|
- let match = true;
|
|
|
- const inputs = Array.from(node.input);
|
|
|
- if (inputs.length > parameters.length) {
|
|
|
- match = false;
|
|
|
- }
|
|
|
- while (inputs.length > 0 && match) {
|
|
|
- match = false;
|
|
|
- const input = inputs.shift();
|
|
|
- delete input.metadata;
|
|
|
- const parameter = parameters.shift();
|
|
|
- switch (parameter.type) {
|
|
|
- case 'Tensor': {
|
|
|
- if ((input.constant === undefined && input.list === undefined) || input.constant === null) {
|
|
|
- input.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- else {
|
|
|
- inputs.unshift(input);
|
|
|
- match = true;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'int64': {
|
|
|
- const value = parseInt(input.constant);
|
|
|
- if (input.constant !== undefined && Number.isInteger(value)) {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.i = value;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'float32': {
|
|
|
- const value = parseFloat(input.constant);
|
|
|
- if (input.constant !== undefined && !isNaN(value)) {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.f = value;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'int64[]': {
|
|
|
- if (Array.isArray(input.list)) {
|
|
|
- const list = input.list.map((item) => parseInt(item));
|
|
|
- if (list.every((value) => Number.isInteger(value))) {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.list = new tf.proto.tensorflow.ListValue();
|
|
|
- input.attr.list.i = list;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'boolean': {
|
|
|
- if (input.constant === 'false' || input.constant === '0') {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.b = false;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- else if (input.constant === 'true' || input.constant === '1') {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.b = true;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- case 'Scalar': {
|
|
|
- const value = parseInt(input.constant);
|
|
|
- if (input.constant !== undefined && Number.isInteger(value)) {
|
|
|
- input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- input.attr.i = value;
|
|
|
- input.attr.metadata = parameter;
|
|
|
- match = true;
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- default:
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- if (match) {
|
|
|
- node.metadata = metadata;
|
|
|
- break;
|
|
|
- }
|
|
|
- else {
|
|
|
- for (const input of node.input) {
|
|
|
- delete input.metadata;
|
|
|
- delete input.attr;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- node.input = node.input.filter((input, index) => {
|
|
|
- if (input.attr) {
|
|
|
- const name = input.attr.metadata ? input.attr.metadata.name : index.toString();
|
|
|
- node.attr[name] = input.attr;
|
|
|
- }
|
|
|
- else if (input.constant !== undefined && input.constant !== null) {
|
|
|
- const attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- attr.s = input.constant;
|
|
|
- node.attr[index.toString()] = attr;
|
|
|
- }
|
|
|
- else if (input.list !== undefined) {
|
|
|
- const attr = new tf.proto.tensorflow.AttrValue();
|
|
|
- attr.list = new tf.proto.tensorflow.ListValue();
|
|
|
- attr.list.s = input.list;
|
|
|
- node.attr[index.toString()] = attr;
|
|
|
- }
|
|
|
- return !remove.has(input.name);
|
|
|
- });
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'prim::GetAttr' && node.remove) {
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- if (node.op === 'prim::Constant' && node.remove) {
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- if (node.op === 'prim::ListConstruct' && node.remove) {
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- }
|
|
|
- };
|
|
|
- updatePyTorch(node_map);
|
|
|
- for (const input of input_map.values()) {
|
|
|
- this._inputs.push(input);
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
|
|
|
- }
|
|
|
- }
|
|
|
+ const nodes = graph.node || [];
|
|
|
+ const context = tf.Utility.createGraph(metadata, nodes);
|
|
|
+ this._nodes = context.nodes;
|
|
|
+ this._inputs = context.inputs;
|
|
|
+ this._outputs = context.outputs;
|
|
|
}
|
|
|
else if (bundle) {
|
|
|
const nodeNames = [];
|
|
|
@@ -1053,124 +711,11 @@ tf.Function = class {
|
|
|
output_arg_map.set(name, output.name);
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- const namespaces = new Set();
|
|
|
- const nodes = func.node_def;
|
|
|
- if (nodes) {
|
|
|
- const node_map = new Map();
|
|
|
- for (const node of nodes) {
|
|
|
- const nodeName = node.name;
|
|
|
- node_map.set(nodeName, node);
|
|
|
- if (node.op != 'Const') {
|
|
|
- const lastIndex = nodeName.lastIndexOf('/');
|
|
|
- if (lastIndex != -1) {
|
|
|
- const namespace = nodeName.substring(0, lastIndex);
|
|
|
- namespaces.add(namespace);
|
|
|
- }
|
|
|
- }
|
|
|
- node.output = [];
|
|
|
- }
|
|
|
- for (const node of nodes) {
|
|
|
- const inputs = node.input;
|
|
|
- node.input = [];
|
|
|
- node.controlDependencies = [];
|
|
|
- for (const input of inputs) {
|
|
|
- const split = input.split(':', 3);
|
|
|
- const input_name = split[0];
|
|
|
- const input_index = split.length == 1 ? 0 : parseInt(split[split.length - 1]);
|
|
|
- const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
|
|
|
- const from = node_map.get(from_name);
|
|
|
- const output_name = from_name + (input_index == 0 ? '' : ':' + input_index.toString());
|
|
|
- const input_arg = from ? { name: output_name, from: from } : { name: output_name };
|
|
|
- if (input_name.startsWith('^')) {
|
|
|
- node.controlDependencies.push(input_arg);
|
|
|
- }
|
|
|
- else {
|
|
|
- node.input.push(input_arg);
|
|
|
- }
|
|
|
- if (from) {
|
|
|
- for (let i = from.output.length; i <= input_index; i++) {
|
|
|
- from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
|
|
|
- }
|
|
|
- from.output[input_index].to.push(node);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- for (const node of nodes) {
|
|
|
- if (output_arg_map.has(node.name)) {
|
|
|
- node.output.push({ name: node.name, to: [] });
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- const initializers = new Map();
|
|
|
- const map_tensor = (name, node, kind) => {
|
|
|
- if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const value = node.attr.value;
|
|
|
- if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
|
|
|
- const tensor = new tf.Tensor(value.tensor, name, kind);
|
|
|
- return new tf.Argument(name, tensor.type, tensor);
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- };
|
|
|
- const map_resource = (name, node, tensor) => {
|
|
|
- if (node && node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const dtype = node.attr.dtype.type;
|
|
|
- if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
|
|
|
- return new tf.Argument(name, null, tensor);
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- };
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(node.input[0].name);
|
|
|
- }
|
|
|
- const identity = node.input[0];
|
|
|
- if (identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
- const initializer = map_tensor(node.name, identity.input[0], 'Identity Constant');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(identity.name);
|
|
|
- node_map.delete(node.name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- const initializer = map_tensor(node.name, node, 'Const');
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(node.name);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- }
|
|
|
- }
|
|
|
- for (const node of node_map.values()) {
|
|
|
- if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0 &&
|
|
|
- node.attr && node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
|
|
|
- const tensor = new tf.proto.tensorflow.TensorProto();
|
|
|
- tensor.dtype = node.attr.dtype.type;
|
|
|
- tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
|
|
|
- const initializer = map_resource(node.name, node.input[0], new tf.Tensor(tensor, name, 'Resource Variable'));
|
|
|
- if (initializer) {
|
|
|
- initializers.set(initializer.name, initializer);
|
|
|
- node_map.delete(initializer.name);
|
|
|
- node_map.delete(node.input[0].name);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- for (const node of node_map.values()) {
|
|
|
- this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
|
|
|
- }
|
|
|
- }
|
|
|
+ const nodes = func.node_def || [];
|
|
|
+ const context = tf.Utility.createGraph(metadata, nodes, output_arg_map);
|
|
|
+ this._nodes = context.nodes;
|
|
|
+ this._inputs = this._inputs.concat(context.inputs);
|
|
|
+ this._outputs = this._outputs.concat(context.outputs);
|
|
|
}
|
|
|
|
|
|
get type() {
|
|
|
@@ -2381,7 +1926,368 @@ tf.Utility = class {
|
|
|
tf.Utility._dataTypeKeys = dataTypeKeys;
|
|
|
}
|
|
|
return tf.Utility._dataTypeKeys.get(type);
|
|
|
+ }
|
|
|
|
|
|
+ static createGraph(metadata, nodes, output_arg_map) {
|
|
|
+ const context = {};
|
|
|
+ context.inputs = [];
|
|
|
+ context.outputs = [];
|
|
|
+ context.nodes = [];
|
|
|
+ const namespaces = new Set();
|
|
|
+ const node_map = new Map();
|
|
|
+ for (const node of nodes) {
|
|
|
+ const nodeName = node.name;
|
|
|
+ node_map.set(nodeName, node);
|
|
|
+ if (node.op != 'Const') {
|
|
|
+ const index = nodeName.lastIndexOf('/');
|
|
|
+ if (index != -1) {
|
|
|
+ const namespace = nodeName.substring(0, index);
|
|
|
+ namespaces.add(namespace);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ node.output = [];
|
|
|
+ }
|
|
|
+ for (const node of nodes) {
|
|
|
+ const inputs = node.input;
|
|
|
+ node.input = [];
|
|
|
+ node.controlDependencies = [];
|
|
|
+ for (const input of inputs) {
|
|
|
+ const split = input.split(':', 3);
|
|
|
+ const input_name = split[0];
|
|
|
+ const input_index = split.length == 1 ? 0 : parseInt(split[split.length - 1]);
|
|
|
+ const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
|
|
|
+ const from = node_map.get(from_name);
|
|
|
+ const output_name = input_index == 0 ? from_name : from_name + ':' + input_index.toString();
|
|
|
+ const input_arg = from ? { name: output_name, from: from } : { name: output_name };
|
|
|
+ if (input_name.startsWith('^')) {
|
|
|
+ node.controlDependencies.push(input_arg);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ node.input.push(input_arg);
|
|
|
+ }
|
|
|
+ if (from) {
|
|
|
+ for (let i = from.output.length; i <= input_index; i++) {
|
|
|
+ from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
|
|
|
+ }
|
|
|
+ from.output[input_index].to.push(node);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (output_arg_map) {
|
|
|
+ for (const node of nodes) {
|
|
|
+ if (output_arg_map.has(node.name)) {
|
|
|
+ node.output.push({ name: node.name, to: [] });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const initializers = new Map();
|
|
|
+ const map_tensor = (name, node, kind) => {
|
|
|
+ if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ const value = node.attr.value;
|
|
|
+ if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
|
|
|
+ const tensor = new tf.Tensor(value.tensor, name, kind);
|
|
|
+ return new tf.Argument(name, tensor.type, tensor);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const map_resource = (name, node, tensor) => {
|
|
|
+ if (node && node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ const dtype = node.attr.dtype.type;
|
|
|
+ if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
|
|
|
+ return new tf.Argument(name, null, tensor);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
|
|
|
+ if (initializer) {
|
|
|
+ initializers.set(initializer.name, initializer);
|
|
|
+ node_map.delete(initializer.name);
|
|
|
+ node_map.delete(node.input[0].name);
|
|
|
+ }
|
|
|
+ const identity = node.input[0].from;
|
|
|
+ if (identity && identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ const initializer = map_tensor(node.name, identity.input[0].from, 'Identity Constant');
|
|
|
+ if (initializer) {
|
|
|
+ initializers.set(initializer.name, initializer);
|
|
|
+ node_map.delete(initializer.name);
|
|
|
+ node_map.delete(initializer.name);
|
|
|
+ node_map.delete(identity.name);
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ const initializer = map_tensor(node.name, node, 'Const');
|
|
|
+ if (initializer) {
|
|
|
+ initializers.set(initializer.name, initializer);
|
|
|
+ node_map.delete(node.name);
|
|
|
+ node_map.delete(initializer.name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ if (node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
|
|
|
+ const tensor = new tf.proto.tensorflow.TensorProto();
|
|
|
+ tensor.dtype = node.attr.dtype.type;
|
|
|
+ tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
|
|
|
+ const name = node.name;
|
|
|
+ const initializer = map_resource(name, node.input[0].from, new tf.Tensor(tensor, name, 'Resource Variable'));
|
|
|
+ if (initializer) {
|
|
|
+ initializers.set(initializer.name, initializer);
|
|
|
+ node_map.delete(initializer.name);
|
|
|
+ node_map.delete(node.input[0].name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const input_map = new Map();
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op == 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
|
|
|
+ const dtype = node.attr.dtype;
|
|
|
+ const shape = node.attr.shape;
|
|
|
+ if (dtype && dtype.type && shape && shape.shape) {
|
|
|
+ const name = node.name;
|
|
|
+ const type = new tf.TensorType(dtype.type, shape.shape);
|
|
|
+ const argument = new tf.Argument(name, type, null);
|
|
|
+ input_map.set(name, new tf.Parameter(name, [ argument ]));
|
|
|
+ node_map.delete(name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const updatePyTorch = (node_map) => {
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op === 'prim::Constant' && node.input.length === 0 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
|
|
|
+ const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
+ const match = /{\s*value\s*:\s*(.*)\s*}/.exec(value);
|
|
|
+ if (match) {
|
|
|
+ node.value = match[1].trim();
|
|
|
+ }
|
|
|
+ const empty = /{\s*}/.exec(value);
|
|
|
+ if (empty) {
|
|
|
+ node.value = null;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (node.op === 'prim::GetAttr' && node.input.length === 1 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
|
|
|
+ const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
+ const match = /{\s*name\s*:\s*([A-za-z0-9_]*)\s*}/.exec(value);
|
|
|
+ if (match) {
|
|
|
+ node.value = match[1].trim();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (node.op === 'IO Node' && node.controlDependencies.length === 0) {
|
|
|
+ const shape = node.attr && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape ? node.attr._output_shapes.list.shape[0] : null;
|
|
|
+ const type = shape ? new tf.TensorType('?', shape) : null;
|
|
|
+ if (node.input.length === 0 && node.output.length === 1) {
|
|
|
+ context.inputs.push(new tf.Parameter(node.name, [
|
|
|
+ new tf.Argument(node.output[0].name, type, null)
|
|
|
+ ]));
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ if (node.input.length === 1 && node.output.length === 0) {
|
|
|
+ context.outputs.push(new tf.Parameter(node.name, [
|
|
|
+ new tf.Argument(node.input[0].name, type, null)
|
|
|
+ ]));
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (Object.keys(node.attr).length === 2 &&
|
|
|
+ node.attr.attr && node.attr.attr.s && node.attr._output_shapes) {
|
|
|
+ const value = tf.Utility.decodeText(node.attr.attr.s);
|
|
|
+ if (/\s*/.exec(value) || /{\s*}/.exec(value)) {
|
|
|
+ node.attr = {};
|
|
|
+ delete node._output_shapes;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const remove_input = (input, node) => {
|
|
|
+ const from = input.from;
|
|
|
+ if (from) {
|
|
|
+ for (const output of from.output) {
|
|
|
+ output.to = output.to.filter((to) => to !== node);
|
|
|
+ }
|
|
|
+ if (from.output.every((output) => output.to.length === 0) && from.controlDependencies.length === 0) {
|
|
|
+ from.remove = true;
|
|
|
+ }
|
|
|
+ delete input.from;
|
|
|
+ }
|
|
|
+ };
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op === 'prim::ListConstruct' && node.input.every((input) => input.from.value !== undefined) && node.controlDependencies.length === 0) {
|
|
|
+ node.value = node.input.map((input) => input.from.value);
|
|
|
+ for (const input of node.input) {
|
|
|
+ remove_input(input, node);
|
|
|
+ }
|
|
|
+ node.input = [];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ const remove = new Set();
|
|
|
+ for (let i = 0; i < node.input.length; i++) {
|
|
|
+ const input = node.input[i];
|
|
|
+ const from = input.from;
|
|
|
+ if (from) {
|
|
|
+ if (from.op === 'prim::GetAttr' && from.input.length === 1 && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
+ remove_input(input, node);
|
|
|
+ input.label = from.value;
|
|
|
+ const tensor = new tf.Tensor(null, input.name, from.op);
|
|
|
+ const argument = new tf.Argument(input.name, null, tensor);
|
|
|
+ initializers.set(input.name, argument);
|
|
|
+ }
|
|
|
+ if (from.op === 'prim::Constant' && from.input.length === 0 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
+ input.constant = from.value;
|
|
|
+ remove_input(input, node);
|
|
|
+ remove.add(input.name);
|
|
|
+ }
|
|
|
+ if (from.op === 'prim::ListConstruct' && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
|
|
|
+ input.list = from.value;
|
|
|
+ remove_input(input, node);
|
|
|
+ remove.add(input.name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (node.__metadata__) {
|
|
|
+ for (const metadata of node.__metadata__) {
|
|
|
+ const parameters = Array.prototype.slice.call(metadata.inputs || []).concat(Array.prototype.slice.call(metadata.attributes || []));
|
|
|
+ let match = true;
|
|
|
+ const inputs = Array.from(node.input);
|
|
|
+ if (inputs.length > parameters.length) {
|
|
|
+ match = false;
|
|
|
+ }
|
|
|
+ while (inputs.length > 0 && match) {
|
|
|
+ match = false;
|
|
|
+ const input = inputs.shift();
|
|
|
+ delete input.metadata;
|
|
|
+ const parameter = parameters.shift();
|
|
|
+ switch (parameter.type) {
|
|
|
+ case 'Tensor': {
|
|
|
+ if ((input.constant === undefined && input.list === undefined) || input.constant === null) {
|
|
|
+ input.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ inputs.unshift(input);
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'int64': {
|
|
|
+ const value = parseInt(input.constant);
|
|
|
+ if (input.constant !== undefined && Number.isInteger(value)) {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.i = value;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'float32': {
|
|
|
+ const value = parseFloat(input.constant);
|
|
|
+ if (input.constant !== undefined && !isNaN(value)) {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.f = value;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'int64[]': {
|
|
|
+ if (Array.isArray(input.list)) {
|
|
|
+ const list = input.list.map((item) => parseInt(item));
|
|
|
+ if (list.every((value) => Number.isInteger(value))) {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.list = new tf.proto.tensorflow.ListValue();
|
|
|
+ input.attr.list.i = list;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'boolean': {
|
|
|
+ if (input.constant === 'false' || input.constant === '0') {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.b = false;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ else if (input.constant === 'true' || input.constant === '1') {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.b = true;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'Scalar': {
|
|
|
+ const value = parseInt(input.constant);
|
|
|
+ if (input.constant !== undefined && Number.isInteger(value)) {
|
|
|
+ input.attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ input.attr.i = value;
|
|
|
+ input.attr.metadata = parameter;
|
|
|
+ match = true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (match) {
|
|
|
+ node.metadata = metadata;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ for (const input of node.input) {
|
|
|
+ delete input.metadata;
|
|
|
+ delete input.attr;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ node.input = node.input.filter((input, index) => {
|
|
|
+ if (input.attr) {
|
|
|
+ const name = input.attr.metadata ? input.attr.metadata.name : index.toString();
|
|
|
+ node.attr[name] = input.attr;
|
|
|
+ }
|
|
|
+ else if (input.constant !== undefined && input.constant !== null) {
|
|
|
+ const attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ attr.s = input.constant;
|
|
|
+ node.attr[index.toString()] = attr;
|
|
|
+ }
|
|
|
+ else if (input.list !== undefined) {
|
|
|
+ const attr = new tf.proto.tensorflow.AttrValue();
|
|
|
+ attr.list = new tf.proto.tensorflow.ListValue();
|
|
|
+ attr.list.s = input.list;
|
|
|
+ node.attr[index.toString()] = attr;
|
|
|
+ }
|
|
|
+ return !remove.has(input.name);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ if (node.op === 'prim::GetAttr' && node.remove) {
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ if (node.op === 'prim::Constant' && node.remove) {
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ if (node.op === 'prim::ListConstruct' && node.remove) {
|
|
|
+ node_map.delete(node.name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ updatePyTorch(node_map);
|
|
|
+ for (const input of input_map.values()) {
|
|
|
+ context.inputs.push(input);
|
|
|
+ }
|
|
|
+ for (const node of node_map.values()) {
|
|
|
+ context.nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
|
|
|
+ }
|
|
|
+ return context;
|
|
|
}
|
|
|
};
|
|
|
|