Lutz Roeder 4 rokov pred
rodič
commit
390a225316
1 zmenil súbory, kde vykonal 395 pridanie a 489 odobranie
  1. 395 489
      source/tf.js

+ 395 - 489
source/tf.js

@@ -103,24 +103,29 @@ tf.ModelFactory = class {
                 }
                 if (saved_model.meta_graphs.every((meta_graph) => meta_graph.graph_def.node.every((node) => node.op.startsWith('aten::') || node.op.startsWith('prim::') || node.op === 'IO Node'))) {
                     producer = 'PyTorch';
-                    return context.request('pytorch-metadata.json', 'utf-8', null).then((data) => {
-                        const metadata = new Map();
-                        for (const item of JSON.parse(data)) {
-                            const index = item.name.indexOf(':');
-                            const key = (index !== -1) ? item.name.substring(0, index) : item.name;
-                            const name = key.replace(/^torch\./, 'aten::');
-                            if (!metadata.has(name)) {
-                                metadata.set(name, []);
+                    const openPyTorchMetadata = (context, saved_model) => {
+                        return context.request('pytorch-metadata.json', 'utf-8', null).then((data) => {
+                            const metadata = new Map();
+                            for (const item of JSON.parse(data)) {
+                                const index = item.name.indexOf(':');
+                                const key = (index !== -1) ? item.name.substring(0, index) : item.name;
+                                const name = key.replace(/^torch\./, 'aten::');
+                                if (!metadata.has(name)) {
+                                    metadata.set(name, []);
+                                }
+                                metadata.get(name).push(item);
                             }
-                            metadata.get(name).push(item);
-                        }
-                        for (const meta_graph of saved_model.meta_graphs) {
-                            for (const node of meta_graph.graph_def.node) {
-                                node.__metadata__ = Array.from(metadata.get(node.op) || []);
+                            for (const meta_graph of saved_model.meta_graphs) {
+                                for (const node of meta_graph.graph_def.node) {
+                                    node.__metadata__ = Array.from(metadata.get(node.op) || []);
+                                }
                             }
-                        }
-                        return openModel(saved_model, format, producer, null);
-                    }).catch(() => {
+                            return saved_model;
+                        }).catch(() => {
+                            return saved_model;
+                        });
+                    };
+                    return openPyTorchMetadata(context, saved_model).then((saved_model) => {
                         return openModel(saved_model, format, producer, null);
                     });
                 }
@@ -504,9 +509,9 @@ tf.Model = class {
             }
         }
         else {
-            this._graphs.push(new tf.Graph(metadata, null, '', bundle));
+            const graph = new tf.Graph(metadata, null, '', bundle);
+            this._graphs.push(graph);
         }
-
     }
 
     get format() {
@@ -529,11 +534,11 @@ tf.Model = class {
 tf.Graph = class {
 
     constructor(metadata, meta_graph, name, bundle) {
-        this._version = null;
         this._name = name;
         this._inputs = [];
         this._outputs = [];
         this._nodes = [];
+        this._version = null;
 
         if (meta_graph && meta_graph.graph_def) {
             const graph = meta_graph.graph_def;
@@ -549,359 +554,12 @@ tf.Graph = class {
             if (meta_graph.meta_info_def && meta_graph.meta_info_def.tags) {
                 this._tags = meta_graph.meta_info_def.tags.join(', ');
             }
-
             metadata = new tf.GraphMetadata(metadata, graph.library);
-
-            const nodes = graph.node;
-            if (nodes) {
-                const node_map = new Map();
-                const namespaces = new Set();
-                for (const node of nodes) {
-                    const nodeName = node.name;
-                    node_map.set(nodeName, node);
-                    if (node.op != 'Const') {
-                        const index = nodeName.lastIndexOf('/');
-                        if (index != -1) {
-                            const namespace = nodeName.substring(0, index);
-                            namespaces.add(namespace);
-                        }
-                    }
-                    node.output = [];
-                }
-                for (const node of nodes) {
-                    const inputs = node.input;
-                    node.input = [];
-                    node.controlDependencies = [];
-                    for (const input of inputs) {
-                        const split = input.split(':', 2);
-                        const input_name = split[0];
-                        const input_index = split.length == 1 ? 0 : parseInt(split[1]);
-                        const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
-                        const from = node_map.get(from_name);
-                        const output_name = input_index == 0 ? from_name : from_name + ':' + input_index.toString();
-                        const input_arg = from ? { name: output_name, from: from } : { name: output_name };
-                        if (input_name.startsWith('^')) {
-                            node.controlDependencies.push(input_arg);
-                        }
-                        else {
-                            node.input.push(input_arg);
-                        }
-                        if (from) {
-                            for (let i = from.output.length; i <= input_index; i++) {
-                                from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
-                            }
-                            from.output[input_index].to.push(node);
-                        }
-                    }
-                }
-
-                const initializers = new Map();
-                const map_tensor = (name, node, kind) => {
-                    if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        const value = node.attr.value;
-                        if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
-                            const tensor = new tf.Tensor(value.tensor, name, kind);
-                            return new tf.Argument(name, tensor.type, tensor);
-                        }
-                    }
-                    return null;
-                };
-                const map_resource = (name, node, tensor) => {
-                    if (node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        const dtype = node.attr.dtype.type;
-                        if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
-                            return new tf.Argument(name, null, tensor);
-                        }
-                    }
-                    return null;
-                };
-                for (const node of node_map.values()) {
-                    if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
-                        if (initializer) {
-                            initializers.set(initializer.name, initializer);
-                            node_map.delete(initializer.name);
-                            node_map.delete(node.input[0].name);
-                        }
-                        const identity = node.input[0].from;
-                        if (identity && identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                            const initializer = map_tensor(node.name, identity.input[0].from, 'Identity Constant');
-                            if (initializer) {
-                                initializers.set(initializer.name, initializer);
-                                node_map.delete(initializer.name);
-                                node_map.delete(initializer.name);
-                                node_map.delete(identity.name);
-                                node_map.delete(node.name);
-                            }
-                        }
-                    }
-                }
-                for (const node of node_map.values()) {
-                    const initializer = map_tensor(node.name, node, 'Const');
-                    if (initializer) {
-                        initializers.set(initializer.name, initializer);
-                        node_map.delete(node.name);
-                        node_map.delete(initializer.name);
-                    }
-                }
-                for (const node of node_map.values()) {
-                    if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        if (node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
-                            const tensor = new tf.proto.tensorflow.TensorProto();
-                            tensor.dtype = node.attr.dtype.type;
-                            tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
-                            const initializer = map_resource(node.name, node.input[0].from, new tf.Tensor(tensor, name, 'Resource Variable'));
-                            if (initializer) {
-                                initializers.set(initializer.name, initializer);
-                                node_map.delete(initializer.name);
-                                node_map.delete(node.input[0].name);
-                            }
-                        }
-                    }
-                }
-                const input_map = new Map();
-                for (const node of node_map.values()) {
-                    if (node.op == 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        const dtype = node.attr.dtype;
-                        const shape = node.attr.shape;
-                        if (dtype && dtype.type && shape && shape.shape) {
-                            const name = node.name;
-                            const type = new tf.TensorType(dtype.type, shape.shape);
-                            const argument = new tf.Argument(name, type, null);
-                            input_map.set(name, new tf.Parameter(name, [ argument ]));
-                            node_map.delete(name);
-                        }
-                    }
-                }
-                const updatePyTorch = (node_map) => {
-                    for (const node of node_map.values()) {
-                        if (node.op === 'prim::Constant' && node.input.length === 0 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
-                            const value = tf.Utility.decodeText(node.attr.attr.s);
-                            const match = /{\s*value\s*:\s*(.*)\s*}/.exec(value);
-                            if (match) {
-                                node.value = match[1].trim();
-                            }
-                            const empty = /{\s*}/.exec(value);
-                            if (empty) {
-                                node.value = null;
-                            }
-                        }
-                        if (node.op === 'prim::GetAttr' && node.input.length === 1 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
-                            const value = tf.Utility.decodeText(node.attr.attr.s);
-                            const match = /{\s*name\s*:\s*([A-za-z0-9_]*)\s*}/.exec(value);
-                            if (match) {
-                                node.value = match[1].trim();
-                            }
-                        }
-                        if (node.op === 'IO Node' && node.controlDependencies.length === 0) {
-                            const shape = node.attr && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape ? node.attr._output_shapes.list.shape[0] : null;
-                            const type = shape ? new tf.TensorType('?', shape) : null;
-                            if (node.input.length === 0 && node.output.length === 1) {
-                                this._inputs.push(new tf.Parameter(node.name, [
-                                    new tf.Argument(node.output[0].name, type, null)
-                                ]));
-                                node_map.delete(node.name);
-                            }
-                            if (node.input.length === 1 && node.output.length === 0) {
-                                this._outputs.push(new tf.Parameter(node.name, [
-                                    new tf.Argument(node.input[0].name, type, null)
-                                ]));
-                                node_map.delete(node.name);
-                            }
-                        }
-                        if (Object.keys(node.attr).length === 2 &&
-                            node.attr.attr && node.attr.attr.s && node.attr._output_shapes) {
-                            const value = tf.Utility.decodeText(node.attr.attr.s);
-                            if (/\s*/.exec(value) || /{\s*}/.exec(value)) {
-                                node.attr = {};
-                                delete node._output_shapes;
-                            }
-                        }
-                    }
-                    const remove_input = (input, node) => {
-                        const from = input.from;
-                        if (from) {
-                            for (const output of from.output) {
-                                output.to = output.to.filter((to) => to !== node);
-                            }
-                            if (from.output.every((output) => output.to.length === 0) && from.controlDependencies.length === 0) {
-                                from.remove = true;
-                            }
-                            delete input.from;
-                        }
-                    };
-                    for (const node of node_map.values()) {
-                        if (node.op === 'prim::ListConstruct' && node.input.every((input) => input.from.value !== undefined) && node.controlDependencies.length === 0) {
-                            node.value = node.input.map((input) => input.from.value);
-                            for (const input of node.input) {
-                                remove_input(input, node);
-                            }
-                            node.input = [];
-                        }
-                    }
-                    for (const node of node_map.values()) {
-                        const remove = new Set();
-                        for (let i = 0; i < node.input.length; i++) {
-                            const input = node.input[i];
-                            const from = input.from;
-                            if (from) {
-                                if (from.op === 'prim::GetAttr' && from.input.length === 1 && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
-                                    remove_input(input, node);
-                                    input.label = from.value;
-                                    const tensor = new tf.Tensor(null, input.name, from.op);
-                                    const argument = new tf.Argument(input.name, null, tensor);
-                                    initializers.set(input.name, argument);
-                                }
-                                if (from.op === 'prim::Constant' && from.input.length === 0 && from.controlDependencies.length === 0 && from.value !== undefined) {
-                                    input.constant = from.value;
-                                    remove_input(input, node);
-                                    remove.add(input.name);
-                                }
-                                if (from.op === 'prim::ListConstruct' && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
-                                    input.list = from.value;
-                                    remove_input(input, node);
-                                    remove.add(input.name);
-                                }
-                            }
-                        }
-                        if (node.__metadata__) {
-                            for (const metadata of node.__metadata__) {
-                                const parameters = Array.prototype.slice.call(metadata.inputs || []).concat(Array.prototype.slice.call(metadata.attributes || []));
-                                let match = true;
-                                const inputs = Array.from(node.input);
-                                if (inputs.length > parameters.length) {
-                                    match = false;
-                                }
-                                while (inputs.length > 0 && match) {
-                                    match = false;
-                                    const input = inputs.shift();
-                                    delete input.metadata;
-                                    const parameter = parameters.shift();
-                                    switch (parameter.type) {
-                                        case 'Tensor': {
-                                            if ((input.constant === undefined && input.list === undefined) || input.constant === null) {
-                                                input.metadata = parameter;
-                                                match = true;
-                                            }
-                                            else {
-                                                inputs.unshift(input);
-                                                match = true;
-                                            }
-                                            break;
-                                        }
-                                        case 'int64': {
-                                            const value = parseInt(input.constant);
-                                            if (input.constant !== undefined && Number.isInteger(value)) {
-                                                input.attr = new tf.proto.tensorflow.AttrValue();
-                                                input.attr.i = value;
-                                                input.attr.metadata = parameter;
-                                                match = true;
-                                            }
-                                            break;
-                                        }
-                                        case 'float32': {
-                                            const value = parseFloat(input.constant);
-                                            if (input.constant !== undefined && !isNaN(value)) {
-                                                input.attr = new tf.proto.tensorflow.AttrValue();
-                                                input.attr.f = value;
-                                                input.attr.metadata = parameter;
-                                                match = true;
-                                            }
-                                            break;
-                                        }
-                                        case 'int64[]': {
-                                            if (Array.isArray(input.list)) {
-                                                const list = input.list.map((item) => parseInt(item));
-                                                if (list.every((value) => Number.isInteger(value))) {
-                                                    input.attr = new tf.proto.tensorflow.AttrValue();
-                                                    input.attr.list = new tf.proto.tensorflow.ListValue();
-                                                    input.attr.list.i = list;
-                                                    input.attr.metadata = parameter;
-                                                    match = true;
-                                                }
-                                            }
-                                            break;
-                                        }
-                                        case 'boolean': {
-                                            if (input.constant === 'false' || input.constant === '0') {
-                                                input.attr = new tf.proto.tensorflow.AttrValue();
-                                                input.attr.b = false;
-                                                input.attr.metadata = parameter;
-                                                match = true;
-                                            }
-                                            else if (input.constant === 'true' || input.constant === '1') {
-                                                input.attr = new tf.proto.tensorflow.AttrValue();
-                                                input.attr.b = true;
-                                                input.attr.metadata = parameter;
-                                                match = true;
-                                            }
-                                            break;
-                                        }
-                                        case 'Scalar': {
-                                            const value = parseInt(input.constant);
-                                            if (input.constant !== undefined && Number.isInteger(value)) {
-                                                input.attr = new tf.proto.tensorflow.AttrValue();
-                                                input.attr.i = value;
-                                                input.attr.metadata = parameter;
-                                                match = true;
-                                            }
-                                            break;
-                                        }
-                                        default:
-                                            break;
-                                    }
-                                }
-                                if (match) {
-                                    node.metadata = metadata;
-                                    break;
-                                }
-                                else {
-                                    for (const input of node.input) {
-                                        delete input.metadata;
-                                        delete input.attr;
-                                    }
-                                }
-                            }
-                        }
-                        node.input = node.input.filter((input, index) => {
-                            if (input.attr) {
-                                const name = input.attr.metadata ? input.attr.metadata.name : index.toString();
-                                node.attr[name] = input.attr;
-                            }
-                            else if (input.constant !== undefined && input.constant !== null) {
-                                const attr = new tf.proto.tensorflow.AttrValue();
-                                attr.s = input.constant;
-                                node.attr[index.toString()] = attr;
-                            }
-                            else if (input.list !== undefined) {
-                                const attr = new tf.proto.tensorflow.AttrValue();
-                                attr.list = new tf.proto.tensorflow.ListValue();
-                                attr.list.s = input.list;
-                                node.attr[index.toString()] = attr;
-                            }
-                            return !remove.has(input.name);
-                        });
-                    }
-                    for (const node of node_map.values()) {
-                        if (node.op === 'prim::GetAttr' && node.remove) {
-                            node_map.delete(node.name);
-                        }
-                        if (node.op === 'prim::Constant' && node.remove) {
-                            node_map.delete(node.name);
-                        }
-                        if (node.op === 'prim::ListConstruct' && node.remove) {
-                            node_map.delete(node.name);
-                        }
-                    }
-                };
-                updatePyTorch(node_map);
-                for (const input of input_map.values()) {
-                    this._inputs.push(input);
-                }
-                for (const node of node_map.values()) {
-                    this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
-                }
-            }
+            const nodes = graph.node || [];
+            const context = tf.Utility.createGraph(metadata, nodes);
+            this._nodes = context.nodes;
+            this._inputs = context.inputs;
+            this._outputs = context.outputs;
         }
         else if (bundle) {
             const nodeNames = [];
@@ -1053,124 +711,11 @@ tf.Function = class {
                 output_arg_map.set(name, output.name);
             }
         }
-
-        const namespaces = new Set();
-        const nodes = func.node_def;
-        if (nodes) {
-            const node_map = new Map();
-            for (const node of nodes) {
-                const nodeName = node.name;
-                node_map.set(nodeName, node);
-                if (node.op != 'Const') {
-                    const lastIndex = nodeName.lastIndexOf('/');
-                    if (lastIndex != -1) {
-                        const namespace = nodeName.substring(0, lastIndex);
-                        namespaces.add(namespace);
-                    }
-                }
-                node.output = [];
-            }
-            for (const node of nodes) {
-                const inputs = node.input;
-                node.input = [];
-                node.controlDependencies = [];
-                for (const input of inputs) {
-                    const split = input.split(':', 3);
-                    const input_name = split[0];
-                    const input_index = split.length == 1 ? 0 : parseInt(split[split.length - 1]);
-                    const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
-                    const from = node_map.get(from_name);
-                    const output_name = from_name + (input_index == 0 ? '' : ':' + input_index.toString());
-                    const input_arg = from ? { name: output_name, from: from } : { name: output_name };
-                    if (input_name.startsWith('^')) {
-                        node.controlDependencies.push(input_arg);
-                    }
-                    else {
-                        node.input.push(input_arg);
-                    }
-                    if (from) {
-                        for (let i = from.output.length; i <= input_index; i++) {
-                            from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
-                        }
-                        from.output[input_index].to.push(node);
-                    }
-                }
-            }
-
-            for (const node of nodes) {
-                if (output_arg_map.has(node.name)) {
-                    node.output.push({ name: node.name, to: [] });
-                }
-            }
-
-            const initializers = new Map();
-            const map_tensor = (name, node, kind) => {
-                if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                    const value = node.attr.value;
-                    if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
-                        const tensor = new tf.Tensor(value.tensor, name, kind);
-                        return new tf.Argument(name, tensor.type, tensor);
-                    }
-                }
-                return null;
-            };
-            const map_resource = (name, node, tensor) => {
-                if (node && node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                    const dtype = node.attr.dtype.type;
-                    if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
-                        return new tf.Argument(name, null, tensor);
-                    }
-                }
-                return null;
-            };
-            for (const node of node_map.values()) {
-                if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                    const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
-                    if (initializer) {
-                        initializers.set(initializer.name, initializer);
-                        node_map.delete(initializer.name);
-                        node_map.delete(node.input[0].name);
-                    }
-                    const identity = node.input[0];
-                    if (identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
-                        const initializer = map_tensor(node.name, identity.input[0], 'Identity Constant');
-                        if (initializer) {
-                            initializers.set(initializer.name, initializer);
-                            node_map.delete(initializer.name);
-                            node_map.delete(initializer.name);
-                            node_map.delete(identity.name);
-                            node_map.delete(node.name);
-                        }
-                    }
-                }
-            }
-            for (const node of node_map.values()) {
-                const initializer = map_tensor(node.name, node, 'Const');
-                if (initializer) {
-                    initializers.set(initializer.name, initializer);
-                    node_map.delete(node.name);
-                    node_map.delete(initializer.name);
-                }
-            }
-            for (const node of node_map.values()) {
-                if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0 &&
-                    node.attr && node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
-                    const tensor = new tf.proto.tensorflow.TensorProto();
-                    tensor.dtype = node.attr.dtype.type;
-                    tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
-                    const initializer = map_resource(node.name, node.input[0], new tf.Tensor(tensor, name, 'Resource Variable'));
-                    if (initializer) {
-                        initializers.set(initializer.name, initializer);
-                        node_map.delete(initializer.name);
-                        node_map.delete(node.input[0].name);
-                    }
-                }
-            }
-
-            for (const node of node_map.values()) {
-                this._nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
-            }
-        }
+        const nodes = func.node_def || [];
+        const context = tf.Utility.createGraph(metadata, nodes, output_arg_map);
+        this._nodes = context.nodes;
+        this._inputs = this._inputs.concat(context.inputs);
+        this._outputs = this._outputs.concat(context.outputs);
     }
 
     get type() {
@@ -2381,7 +1926,368 @@ tf.Utility = class {
             tf.Utility._dataTypeKeys = dataTypeKeys;
         }
         return tf.Utility._dataTypeKeys.get(type);
+    }
 
+    static createGraph(metadata, nodes, output_arg_map) {
+        const context = {};
+        context.inputs = [];
+        context.outputs = [];
+        context.nodes = [];
+        const namespaces = new Set();
+        const node_map = new Map();
+        for (const node of nodes) {
+            const nodeName = node.name;
+            node_map.set(nodeName, node);
+            if (node.op != 'Const') {
+                const index = nodeName.lastIndexOf('/');
+                if (index != -1) {
+                    const namespace = nodeName.substring(0, index);
+                    namespaces.add(namespace);
+                }
+            }
+            node.output = [];
+        }
+        for (const node of nodes) {
+            const inputs = node.input;
+            node.input = [];
+            node.controlDependencies = [];
+            for (const input of inputs) {
+                const split = input.split(':', 3);
+                const input_name = split[0];
+                const input_index = split.length == 1 ? 0 : parseInt(split[split.length - 1]);
+                const from_name = input_name.startsWith('^') ? input_name.substring(1) : input_name;
+                const from = node_map.get(from_name);
+                const output_name = input_index == 0 ? from_name : from_name + ':' + input_index.toString();
+                const input_arg = from ? { name: output_name, from: from } : { name: output_name };
+                if (input_name.startsWith('^')) {
+                    node.controlDependencies.push(input_arg);
+                }
+                else {
+                    node.input.push(input_arg);
+                }
+                if (from) {
+                    for (let i = from.output.length; i <= input_index; i++) {
+                        from.output.push({ name: i === 0 ? from_name : from_name + ':' + i.toString(), to: [] });
+                    }
+                    from.output[input_index].to.push(node);
+                }
+            }
+        }
+        if (output_arg_map) {
+            for (const node of nodes) {
+                if (output_arg_map.has(node.name)) {
+                    node.output.push({ name: node.name, to: [] });
+                }
+            }
+        }
+        const initializers = new Map();
+        const map_tensor = (name, node, kind) => {
+            if (node && node.op === 'Const' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                const value = node.attr.value;
+                if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
+                    const tensor = new tf.Tensor(value.tensor, name, kind);
+                    return new tf.Argument(name, tensor.type, tensor);
+                }
+            }
+            return null;
+        };
+        const map_resource = (name, node, tensor) => {
+            if (node && node.op === 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                const dtype = node.attr.dtype.type;
+                if (dtype === tf.proto.tensorflow.DataType.DT_RESOURCE) {
+                    return new tf.Argument(name, null, tensor);
+                }
+            }
+            return null;
+        };
+        for (const node of node_map.values()) {
+            if (node.op === 'Identity' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                const initializer = map_tensor(node.name, node.input[0].from, 'Identity Constant');
+                if (initializer) {
+                    initializers.set(initializer.name, initializer);
+                    node_map.delete(initializer.name);
+                    node_map.delete(node.input[0].name);
+                }
+                const identity = node.input[0].from;
+                if (identity && identity.op === 'Identity' && identity.input.length === 1 && identity.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                    const initializer = map_tensor(node.name, identity.input[0].from, 'Identity Constant');
+                    if (initializer) {
+                        initializers.set(initializer.name, initializer);
+                        node_map.delete(initializer.name);
+                        node_map.delete(initializer.name);
+                        node_map.delete(identity.name);
+                        node_map.delete(node.name);
+                    }
+                }
+            }
+        }
+        for (const node of node_map.values()) {
+            const initializer = map_tensor(node.name, node, 'Const');
+            if (initializer) {
+                initializers.set(initializer.name, initializer);
+                node_map.delete(node.name);
+                node_map.delete(initializer.name);
+            }
+        }
+        for (const node of node_map.values()) {
+            if (node.op === 'ReadVariableOp' && node.input.length === 1 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                if (node.attr && node.attr.dtype && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape) {
+                    const tensor = new tf.proto.tensorflow.TensorProto();
+                    tensor.dtype = node.attr.dtype.type;
+                    tensor.tensor_shape = node.attr._output_shapes.list.shape[0];
+                    const name = node.name;
+                    const initializer = map_resource(name, node.input[0].from,  new tf.Tensor(tensor, name, 'Resource Variable'));
+                    if (initializer) {
+                        initializers.set(initializer.name, initializer);
+                        node_map.delete(initializer.name);
+                        node_map.delete(node.input[0].name);
+                    }
+                }
+            }
+        }
+        const input_map = new Map();
+        for (const node of node_map.values()) {
+            if (node.op == 'Placeholder' && node.input.length === 0 && node.output.length === 1 && node.output[0].to.length === 1 && node.controlDependencies.length === 0) {
+                const dtype = node.attr.dtype;
+                const shape = node.attr.shape;
+                if (dtype && dtype.type && shape && shape.shape) {
+                    const name = node.name;
+                    const type = new tf.TensorType(dtype.type, shape.shape);
+                    const argument = new tf.Argument(name, type, null);
+                    input_map.set(name, new tf.Parameter(name, [ argument ]));
+                    node_map.delete(name);
+                }
+            }
+        }
+        const updatePyTorch = (node_map) => {
+            for (const node of node_map.values()) {
+                if (node.op === 'prim::Constant' && node.input.length === 0 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
+                    const value = tf.Utility.decodeText(node.attr.attr.s);
+                    const match = /{\s*value\s*:\s*(.*)\s*}/.exec(value);
+                    if (match) {
+                        node.value = match[1].trim();
+                    }
+                    const empty = /{\s*}/.exec(value);
+                    if (empty) {
+                        node.value = null;
+                    }
+                }
+                if (node.op === 'prim::GetAttr' && node.input.length === 1 && node.controlDependencies.length === 0 && node.attr && Object.keys(node.attr).length === 1 && node.attr.attr && node.attr.attr.s) {
+                    const value = tf.Utility.decodeText(node.attr.attr.s);
+                    const match = /{\s*name\s*:\s*([A-za-z0-9_]*)\s*}/.exec(value);
+                    if (match) {
+                        node.value = match[1].trim();
+                    }
+                }
+                if (node.op === 'IO Node' && node.controlDependencies.length === 0) {
+                    const shape = node.attr && node.attr._output_shapes && node.attr._output_shapes.list && node.attr._output_shapes.list.shape ? node.attr._output_shapes.list.shape[0] : null;
+                    const type = shape ? new tf.TensorType('?', shape) : null;
+                    if (node.input.length === 0 && node.output.length === 1) {
+                        context.inputs.push(new tf.Parameter(node.name, [
+                            new tf.Argument(node.output[0].name, type, null)
+                        ]));
+                        node_map.delete(node.name);
+                    }
+                    if (node.input.length === 1 && node.output.length === 0) {
+                        context.outputs.push(new tf.Parameter(node.name, [
+                            new tf.Argument(node.input[0].name, type, null)
+                        ]));
+                        node_map.delete(node.name);
+                    }
+                }
+                if (Object.keys(node.attr).length === 2 &&
+                    node.attr.attr && node.attr.attr.s && node.attr._output_shapes) {
+                    const value = tf.Utility.decodeText(node.attr.attr.s);
+                    if (/\s*/.exec(value) || /{\s*}/.exec(value)) {
+                        node.attr = {};
+                        delete node._output_shapes;
+                    }
+                }
+            }
+            const remove_input = (input, node) => {
+                const from = input.from;
+                if (from) {
+                    for (const output of from.output) {
+                        output.to = output.to.filter((to) => to !== node);
+                    }
+                    if (from.output.every((output) => output.to.length === 0) && from.controlDependencies.length === 0) {
+                        from.remove = true;
+                    }
+                    delete input.from;
+                }
+            };
+            for (const node of node_map.values()) {
+                if (node.op === 'prim::ListConstruct' && node.input.every((input) => input.from.value !== undefined) && node.controlDependencies.length === 0) {
+                    node.value = node.input.map((input) => input.from.value);
+                    for (const input of node.input) {
+                        remove_input(input, node);
+                    }
+                    node.input = [];
+                }
+            }
+            for (const node of node_map.values()) {
+                const remove = new Set();
+                for (let i = 0; i < node.input.length; i++) {
+                    const input = node.input[i];
+                    const from = input.from;
+                    if (from) {
+                        if (from.op === 'prim::GetAttr' && from.input.length === 1 && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
+                            remove_input(input, node);
+                            input.label = from.value;
+                            const tensor = new tf.Tensor(null, input.name, from.op);
+                            const argument = new tf.Argument(input.name, null, tensor);
+                            initializers.set(input.name, argument);
+                        }
+                        if (from.op === 'prim::Constant' && from.input.length === 0 && from.controlDependencies.length === 0 && from.value !== undefined) {
+                            input.constant = from.value;
+                            remove_input(input, node);
+                            remove.add(input.name);
+                        }
+                        if (from.op === 'prim::ListConstruct' && from.output.length === 1 && from.controlDependencies.length === 0 && from.value !== undefined) {
+                            input.list = from.value;
+                            remove_input(input, node);
+                            remove.add(input.name);
+                        }
+                    }
+                }
+                if (node.__metadata__) {
+                    for (const metadata of node.__metadata__) {
+                        const parameters = Array.prototype.slice.call(metadata.inputs || []).concat(Array.prototype.slice.call(metadata.attributes || []));
+                        let match = true;
+                        const inputs = Array.from(node.input);
+                        if (inputs.length > parameters.length) {
+                            match = false;
+                        }
+                        while (inputs.length > 0 && match) {
+                            match = false;
+                            const input = inputs.shift();
+                            delete input.metadata;
+                            const parameter = parameters.shift();
+                            switch (parameter.type) {
+                                case 'Tensor': {
+                                    if ((input.constant === undefined && input.list === undefined) || input.constant === null) {
+                                        input.metadata = parameter;
+                                        match = true;
+                                    }
+                                    else {
+                                        inputs.unshift(input);
+                                        match = true;
+                                    }
+                                    break;
+                                }
+                                case 'int64': {
+                                    const value = parseInt(input.constant);
+                                    if (input.constant !== undefined && Number.isInteger(value)) {
+                                        input.attr = new tf.proto.tensorflow.AttrValue();
+                                        input.attr.i = value;
+                                        input.attr.metadata = parameter;
+                                        match = true;
+                                    }
+                                    break;
+                                }
+                                case 'float32': {
+                                    const value = parseFloat(input.constant);
+                                    if (input.constant !== undefined && !isNaN(value)) {
+                                        input.attr = new tf.proto.tensorflow.AttrValue();
+                                        input.attr.f = value;
+                                        input.attr.metadata = parameter;
+                                        match = true;
+                                    }
+                                    break;
+                                }
+                                case 'int64[]': {
+                                    if (Array.isArray(input.list)) {
+                                        const list = input.list.map((item) => parseInt(item));
+                                        if (list.every((value) => Number.isInteger(value))) {
+                                            input.attr = new tf.proto.tensorflow.AttrValue();
+                                            input.attr.list = new tf.proto.tensorflow.ListValue();
+                                            input.attr.list.i = list;
+                                            input.attr.metadata = parameter;
+                                            match = true;
+                                        }
+                                    }
+                                    break;
+                                }
+                                case 'boolean': {
+                                    if (input.constant === 'false' || input.constant === '0') {
+                                        input.attr = new tf.proto.tensorflow.AttrValue();
+                                        input.attr.b = false;
+                                        input.attr.metadata = parameter;
+                                        match = true;
+                                    }
+                                    else if (input.constant === 'true' || input.constant === '1') {
+                                        input.attr = new tf.proto.tensorflow.AttrValue();
+                                        input.attr.b = true;
+                                        input.attr.metadata = parameter;
+                                        match = true;
+                                    }
+                                    break;
+                                }
+                                case 'Scalar': {
+                                    const value = parseInt(input.constant);
+                                    if (input.constant !== undefined && Number.isInteger(value)) {
+                                        input.attr = new tf.proto.tensorflow.AttrValue();
+                                        input.attr.i = value;
+                                        input.attr.metadata = parameter;
+                                        match = true;
+                                    }
+                                    break;
+                                }
+                                default:
+                                    break;
+                            }
+                        }
+                        if (match) {
+                            node.metadata = metadata;
+                            break;
+                        }
+                        else {
+                            for (const input of node.input) {
+                                delete input.metadata;
+                                delete input.attr;
+                            }
+                        }
+                    }
+                }
+                node.input = node.input.filter((input, index) => {
+                    if (input.attr) {
+                        const name = input.attr.metadata ? input.attr.metadata.name : index.toString();
+                        node.attr[name] = input.attr;
+                    }
+                    else if (input.constant !== undefined && input.constant !== null) {
+                        const attr = new tf.proto.tensorflow.AttrValue();
+                        attr.s = input.constant;
+                        node.attr[index.toString()] = attr;
+                    }
+                    else if (input.list !== undefined) {
+                        const attr = new tf.proto.tensorflow.AttrValue();
+                        attr.list = new tf.proto.tensorflow.ListValue();
+                        attr.list.s = input.list;
+                        node.attr[index.toString()] = attr;
+                    }
+                    return !remove.has(input.name);
+                });
+            }
+            for (const node of node_map.values()) {
+                if (node.op === 'prim::GetAttr' && node.remove) {
+                    node_map.delete(node.name);
+                }
+                if (node.op === 'prim::Constant' && node.remove) {
+                    node_map.delete(node.name);
+                }
+                if (node.op === 'prim::ListConstruct' && node.remove) {
+                    node_map.delete(node.name);
+                }
+            }
+        };
+        updatePyTorch(node_map);
+        for (const input of input_map.values()) {
+            context.inputs.push(input);
+        }
+        for (const node of node_map.values()) {
+            context.nodes.push(new tf.Node(metadata, namespaces, node, node.op, node.name, initializers, null));
+        }
+        return context;
     }
 };