|
|
@@ -11,9 +11,9 @@ const numpy = {};
|
|
|
pytorch.ModelFactory = class {
|
|
|
|
|
|
async match(context) {
|
|
|
- const container = await pytorch.Container.open(context);
|
|
|
- if (container) {
|
|
|
- return context.set(container.type, container);
|
|
|
+ const reader = await pytorch.Reader.open(context);
|
|
|
+ if (reader) {
|
|
|
+ return context.set(reader.type, reader);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
@@ -42,7 +42,7 @@ pytorch.ModelFactory = class {
|
|
|
});
|
|
|
await target.read(metadata);
|
|
|
if (!target.format || (!target.modules && !target.module)) {
|
|
|
- throw new pytorch.Error("Container not implemented.");
|
|
|
+ throw new pytorch.Error("Reader not implemented.");
|
|
|
}
|
|
|
return new pytorch.Model(metadata, target);
|
|
|
}
|
|
|
@@ -76,17 +76,17 @@ pytorch.Graph = class {
|
|
|
this.outputs = [];
|
|
|
this.name = name;
|
|
|
this.type = type;
|
|
|
- const values = new Map();
|
|
|
- values.map = (name, type, tensor) => {
|
|
|
+ const context = new pytorch.Context(execution, metadata);
|
|
|
+ context.values.map = (name, type, tensor) => {
|
|
|
if (tensor) {
|
|
|
return new pytorch.Value(name, type, null, tensor);
|
|
|
}
|
|
|
- if (!values.has(name)) {
|
|
|
- values.set(name, new pytorch.Value(name, type, null, tensor));
|
|
|
+ if (!context.values.has(name)) {
|
|
|
+ context.values.set(name, new pytorch.Value(name, type, null, tensor));
|
|
|
} else if (type || tensor) {
|
|
|
throw new pytorch.Error(`Duplicate value '${name}'.`);
|
|
|
}
|
|
|
- return values.get(name);
|
|
|
+ return context.values.get(name);
|
|
|
};
|
|
|
const torch = execution ? execution.torch : null;
|
|
|
if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) {
|
|
|
@@ -188,12 +188,12 @@ pytorch.Graph = class {
|
|
|
}
|
|
|
const identifier = pytorch.Utility.unique(v);
|
|
|
const name = v.debugName() || identifier;
|
|
|
- const value = values.map(identifier);
|
|
|
+ const value = context.values.map(identifier);
|
|
|
this.inputs.push(new pytorch.Argument(name, [value]));
|
|
|
}
|
|
|
for (const value of graph.outputs()) {
|
|
|
const identifier = pytorch.Utility.unique(value);
|
|
|
- this.outputs.push(new pytorch.Argument(identifier, [values.map(identifier)]));
|
|
|
+ this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
|
|
|
}
|
|
|
for (const node of graph.nodes()) {
|
|
|
if (deleted.has(node)) {
|
|
|
@@ -210,29 +210,15 @@ pytorch.Graph = class {
|
|
|
continue;
|
|
|
}
|
|
|
}
|
|
|
- this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values));
|
|
|
+ this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context));
|
|
|
}
|
|
|
} else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) {
|
|
|
const exported_program = module;
|
|
|
const graph = exported_program.graph;
|
|
|
+ const graph_module = exported_program.graph_module;
|
|
|
const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters;
|
|
|
const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers;
|
|
|
const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants;
|
|
|
- const values = new Map();
|
|
|
- values.map = (obj) => {
|
|
|
- if (!values.has(obj)) {
|
|
|
- let type = null;
|
|
|
- const val = obj.meta.get('val');
|
|
|
- if (val && val.dtype) {
|
|
|
- const dataType = val.dtype.__reduce__();
|
|
|
- const shape = new pytorch.TensorShape(val.shape);
|
|
|
- type = new pytorch.TensorType(dataType, shape);
|
|
|
- }
|
|
|
- const value = new pytorch.Value(obj.name, type);
|
|
|
- values.set(obj, value);
|
|
|
- }
|
|
|
- return values.get(obj);
|
|
|
- };
|
|
|
const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
|
|
|
for (const obj of graph.nodes) {
|
|
|
if (obj.op === 'placeholder') {
|
|
|
@@ -242,60 +228,40 @@ pytorch.Graph = class {
|
|
|
const tensor = parameter && parameter.data ? parameter.data : obj.meta.get('val');
|
|
|
const initializer = new pytorch.Tensor(key, tensor);
|
|
|
const value = new pytorch.Value(key, null, null, initializer);
|
|
|
- values.set(obj, value);
|
|
|
+ context.values.set(obj, value);
|
|
|
} else if (inputs_to_buffers.has(obj.name)) {
|
|
|
const key = inputs_to_buffers.get(obj.name);
|
|
|
const buffer = exported_program.state_dict.get(key);
|
|
|
const tensor = buffer || obj.meta.get('val');
|
|
|
const initializer = new pytorch.Tensor(key, tensor);
|
|
|
const value = new pytorch.Value(key, null, null, initializer);
|
|
|
- values.set(obj, value);
|
|
|
+ context.values.set(obj, value);
|
|
|
} else if (inputs_to_lifted_tensor_constants.has(obj.name)) {
|
|
|
const key = inputs_to_lifted_tensor_constants.get(obj.name);
|
|
|
const constant = exported_program.constants.get(key);
|
|
|
const tensor = constant && constant.data ? constant.data : obj.meta.get('val');
|
|
|
const initializer = new pytorch.Tensor(key, tensor);
|
|
|
const value = new pytorch.Value(key, null, null, initializer);
|
|
|
- values.set(obj, value);
|
|
|
+ context.values.set(obj, value);
|
|
|
}
|
|
|
- if (obj.users.size > 1 && values.has(obj)) {
|
|
|
- const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
|
|
|
+ if (obj.users.size > 1 && context.values.has(obj)) {
|
|
|
+ const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context);
|
|
|
this.nodes.push(node);
|
|
|
- values.set(obj, node.outputs[0].value[0]);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- for (const obj of graph.nodes) {
|
|
|
- if (obj.op === 'placeholder') {
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (obj.op === 'call_function') {
|
|
|
- if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
|
|
|
- continue;
|
|
|
- }
|
|
|
- }
|
|
|
- if (obj.op === 'output') {
|
|
|
- for (const output of obj.args) {
|
|
|
- if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
|
|
|
- continue;
|
|
|
- }
|
|
|
- const value = values.map(output);
|
|
|
- const argument = new pytorch.Argument(output.name, [value]);
|
|
|
- this.outputs.push(argument);
|
|
|
+ context.values.set(obj, node.outputs[0].value[0]);
|
|
|
}
|
|
|
- continue;
|
|
|
}
|
|
|
- const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
|
|
|
- this.nodes.push(node);
|
|
|
}
|
|
|
+ context.graph(this, graph_module, false);
|
|
|
for (const input_spec of exported_program.graph_signature.user_inputs) {
|
|
|
if (nodes.has(input_spec)) {
|
|
|
const node = nodes.get(input_spec);
|
|
|
- const value = values.map(node);
|
|
|
+ const value = context.value(node);
|
|
|
const argument = new pytorch.Argument(input_spec, [value]);
|
|
|
this.inputs.push(argument);
|
|
|
}
|
|
|
}
|
|
|
+ } else if (torch && module instanceof torch.fx.GraphModule && module.graph) {
|
|
|
+ context.graph(this, module, true);
|
|
|
} else if (pytorch.Utility.isTensor(module)) {
|
|
|
const node = new pytorch.Node(execution, metadata, null, type, { value: module });
|
|
|
this.nodes.push(node);
|
|
|
@@ -311,7 +277,7 @@ pytorch.Graph = class {
|
|
|
const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
|
|
|
for (const module of modules) {
|
|
|
const type = this.type === 'weights' ? 'Weights' : null;
|
|
|
- const node = new pytorch.Node(execution, metadata, null, type, module, null, values);
|
|
|
+ const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
|
|
|
this.nodes.push(node);
|
|
|
}
|
|
|
}
|
|
|
@@ -344,7 +310,7 @@ pytorch.Value = class Value {
|
|
|
|
|
|
pytorch.Node = class {
|
|
|
|
|
|
- constructor(execution, metadata, name, type, obj, initializers, values, stack) {
|
|
|
+ constructor(execution, metadata, name, type, obj, initializers, context, stack) {
|
|
|
const torch = execution ? execution.torch : null;
|
|
|
const builtins = execution ? execution.builtins : null;
|
|
|
this.name = name || '';
|
|
|
@@ -390,11 +356,11 @@ pytorch.Node = class {
|
|
|
const mapTensor = (value) => {
|
|
|
if (value.identifier && pytorch.Utility.isTensor(value.value)) {
|
|
|
const identifier = value.identifier;
|
|
|
- if (!values.has(identifier)) {
|
|
|
+ if (!context.values.has(identifier)) {
|
|
|
const tensor = new pytorch.Tensor(identifier, value.value);
|
|
|
- values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
|
|
|
+ context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
|
|
|
}
|
|
|
- return values.map(identifier);
|
|
|
+ return context.values.map(identifier);
|
|
|
}
|
|
|
let initializer = null;
|
|
|
let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
|
|
|
@@ -407,7 +373,7 @@ pytorch.Node = class {
|
|
|
if (initializer) {
|
|
|
return new pytorch.Value(identifier, null, null, initializer);
|
|
|
}
|
|
|
- return values.map(identifier);
|
|
|
+ return context.values.map(identifier);
|
|
|
};
|
|
|
for (let i = 0; i < inputs.length; i++) {
|
|
|
const input = inputs[i];
|
|
|
@@ -423,20 +389,20 @@ pytorch.Node = class {
|
|
|
if (type && type instanceof torch.ClassType) {
|
|
|
const obj = input.value;
|
|
|
if (!array && initializers.has(obj)) {
|
|
|
- const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, values);
|
|
|
+ const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context);
|
|
|
argument = new pytorch.Argument(name, node, 'object');
|
|
|
} else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
|
|
|
- const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, values));
|
|
|
+ const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context));
|
|
|
argument = new pytorch.Argument(name, node, 'object[]');
|
|
|
} else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) {
|
|
|
- const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, values));
|
|
|
+ const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context));
|
|
|
argument = new pytorch.Argument(name, node, 'object[]');
|
|
|
} else if (input.value === undefined) {
|
|
|
const identifier = pytorch.Utility.unique(input);
|
|
|
- const value = values.map(identifier);
|
|
|
+ const value = context.values.map(identifier);
|
|
|
argument = new pytorch.Argument(name, [value]);
|
|
|
} else {
|
|
|
- const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, values);
|
|
|
+ const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context);
|
|
|
argument = new pytorch.Argument(name, node, 'object');
|
|
|
}
|
|
|
} else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) {
|
|
|
@@ -465,13 +431,13 @@ pytorch.Node = class {
|
|
|
return value.value;
|
|
|
}
|
|
|
const identifier = pytorch.Utility.unique(value);
|
|
|
- return values.map(identifier);
|
|
|
+ return context.values.map(identifier);
|
|
|
});
|
|
|
const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
|
|
|
argument = new pytorch.Argument(name, args, type);
|
|
|
} else {
|
|
|
const identifier = pytorch.Utility.unique(input);
|
|
|
- argument = new pytorch.Argument(name, [values.map(identifier)]);
|
|
|
+ argument = new pytorch.Argument(name, [context.values.map(identifier)]);
|
|
|
}
|
|
|
} else if (input.type() instanceof torch.StringType && typeof input.value === 'string') {
|
|
|
argument = new pytorch.Argument(name, input.value, 'string');
|
|
|
@@ -485,7 +451,7 @@ pytorch.Node = class {
|
|
|
argument = new pytorch.Argument(name, null, 'attribute');
|
|
|
} else {
|
|
|
const identifier = pytorch.Utility.unique(input);
|
|
|
- const value = values.map(identifier);
|
|
|
+ const value = context.values.map(identifier);
|
|
|
argument = new pytorch.Argument(name, [value]);
|
|
|
}
|
|
|
} else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
|
|
|
@@ -508,14 +474,14 @@ pytorch.Node = class {
|
|
|
if (initializer) {
|
|
|
return new pytorch.Value(identifier, null, null, initializer);
|
|
|
}
|
|
|
- return values.map(identifier);
|
|
|
+ return context.values.map(identifier);
|
|
|
});
|
|
|
argument = new pytorch.Argument(name, args);
|
|
|
} else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) {
|
|
|
const args = input.value.map((value) => {
|
|
|
if (value instanceof torch.Value) {
|
|
|
const identifier = pytorch.Utility.unique(value);
|
|
|
- return values.map(identifier);
|
|
|
+ return context.values.map(identifier);
|
|
|
}
|
|
|
return value;
|
|
|
});
|
|
|
@@ -540,7 +506,7 @@ pytorch.Node = class {
|
|
|
output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
|
|
|
list = output.uses()[0].user.outputs();
|
|
|
}
|
|
|
- const args = list.map((output) => values.map(pytorch.Utility.unique(output)));
|
|
|
+ const args = list.map((output) => context.values.map(pytorch.Utility.unique(output)));
|
|
|
const argument = new pytorch.Argument(name, args);
|
|
|
this.outputs.push(argument);
|
|
|
}
|
|
|
@@ -611,15 +577,24 @@ pytorch.Node = class {
|
|
|
for (const [name, arg] of args) {
|
|
|
const type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
|
|
|
if (arg instanceof torch.fx.node.Node) {
|
|
|
- const value = values.map(arg);
|
|
|
- const argument = new pytorch.Argument(name, [value]);
|
|
|
+ let argument = null;
|
|
|
+ if (arg.op === 'get_attr' && arg.users.size === 1) {
|
|
|
+ const subgraph = context.function(arg);
|
|
|
+ if (subgraph) {
|
|
|
+ argument = new pytorch.Argument(name, subgraph, 'function');
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!argument) {
|
|
|
+ const value = context.value(arg);
|
|
|
+ argument = new pytorch.Argument(name, [value]);
|
|
|
+ }
|
|
|
this.inputs.push(argument);
|
|
|
} else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) {
|
|
|
- const list = arg.map((arg) => arg === null ? null : values.map(arg));
|
|
|
+ const list = arg.map((arg) => arg === null ? null : context.value(arg));
|
|
|
const argument = new pytorch.Argument(name, list);
|
|
|
this.inputs.push(argument);
|
|
|
} else if (Array.isArray(arg)) {
|
|
|
- const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? values.map(arg) : arg);
|
|
|
+ const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg);
|
|
|
const argument = new pytorch.Argument(name, list, type || 'attribute');
|
|
|
this.inputs.push(argument);
|
|
|
} else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) {
|
|
|
@@ -643,7 +618,7 @@ pytorch.Node = class {
|
|
|
}
|
|
|
for (let i = 0; i < outputs.length; i++) {
|
|
|
const node = outputs[i];
|
|
|
- const value = values.map(node);
|
|
|
+ const value = context.value(node);
|
|
|
const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
|
|
|
const argument = new pytorch.Argument(name, [value]);
|
|
|
this.outputs.push(argument);
|
|
|
@@ -671,19 +646,26 @@ pytorch.Node = class {
|
|
|
} else if (obj.op === 'placeholder') {
|
|
|
this.type = { name: obj.op };
|
|
|
{
|
|
|
- const value = values.map(obj);
|
|
|
+ const value = context.value(obj);
|
|
|
const argument = new pytorch.Argument('value', [value]);
|
|
|
this.inputs.push(argument);
|
|
|
}
|
|
|
{
|
|
|
- const value = values.map({ name: obj.name, meta: obj.meta });
|
|
|
+ const node = new torch.fx.node.Node(null, obj.name);
|
|
|
+ node.meta = obj.meta;
|
|
|
+ const value = context.value(node);
|
|
|
const argument = new pytorch.Argument('value', [value]);
|
|
|
this.outputs.push(argument);
|
|
|
}
|
|
|
} else if (obj.op === 'get_attr') {
|
|
|
this.type = { name: obj.op };
|
|
|
- this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
|
|
|
- const value = values.map(obj);
|
|
|
+ const subgraph = context.function(obj);
|
|
|
+ if (subgraph) {
|
|
|
+ this.inputs.push(new pytorch.Argument('name', subgraph, 'function'));
|
|
|
+ } else {
|
|
|
+ this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
|
|
|
+ }
|
|
|
+ const value = context.value(obj);
|
|
|
this.outputs.push(new pytorch.Argument('value', [value]));
|
|
|
} else if (obj.op === 'root') {
|
|
|
this.type = { name: obj.op };
|
|
|
@@ -783,7 +765,7 @@ pytorch.Node = class {
|
|
|
const argument = new pytorch.Argument(name, args, null, visible);
|
|
|
this.inputs.push(argument);
|
|
|
if (value && value.__variable__) {
|
|
|
- const argument = new pytorch.Argument(name, [values.map(value.__variable__)]);
|
|
|
+ const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]);
|
|
|
this.outputs.push(argument);
|
|
|
}
|
|
|
}
|
|
|
@@ -816,7 +798,7 @@ pytorch.Node = class {
|
|
|
const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
|
|
|
stack.add(value);
|
|
|
const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
|
|
|
- const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, values, stack);
|
|
|
+ const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, context, stack);
|
|
|
stack.delete(value);
|
|
|
return node;
|
|
|
});
|
|
|
@@ -829,7 +811,7 @@ pytorch.Node = class {
|
|
|
const list = value.filter((value) => !stack.has(value));
|
|
|
const nodes = list.map((value) => {
|
|
|
stack.add(value);
|
|
|
- const node = new pytorch.Node(execution, metadata, null, null, value, initializers, values, stack);
|
|
|
+ const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
|
|
|
stack.delete(value);
|
|
|
return node;
|
|
|
});
|
|
|
@@ -837,7 +819,7 @@ pytorch.Node = class {
|
|
|
this.inputs.push(argument);
|
|
|
} else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
|
|
|
stack.add(value);
|
|
|
- const node = new pytorch.Node(execution, metadata, null, null, value, initializers, values, stack);
|
|
|
+ const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
|
|
|
stack.delete(value);
|
|
|
const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value);
|
|
|
const argument = new pytorch.Argument(name, node, 'object', visible);
|
|
|
@@ -994,27 +976,119 @@ pytorch.TensorShape = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container = class {
|
|
|
+pytorch.Context = class {
|
|
|
+
|
|
|
+ constructor(execution, metadata) {
|
|
|
+ this.execution = execution;
|
|
|
+ this.torch = execution ? execution.__import__('torch') : null;
|
|
|
+ this.metadata = metadata;
|
|
|
+ this.values = new Map();
|
|
|
+ this.modules = new Map();
|
|
|
+ }
|
|
|
+
|
|
|
+ value(obj) {
|
|
|
+ const torch = this.torch;
|
|
|
+ if (obj instanceof torch.fx.node.Node) {
|
|
|
+ if (!this.values.has(obj)) {
|
|
|
+ let type = null;
|
|
|
+ const val = obj.meta ? obj.meta.get('val') : null;
|
|
|
+ if (val && val.dtype) {
|
|
|
+ const dataType = val.dtype.__reduce__();
|
|
|
+ const shape = new pytorch.TensorShape(val.shape);
|
|
|
+ type = new pytorch.TensorType(dataType, shape);
|
|
|
+ }
|
|
|
+ const value = new pytorch.Value(obj.name, type);
|
|
|
+ this.values.set(obj, value);
|
|
|
+ }
|
|
|
+ return this.values.get(obj);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ function(obj) {
|
|
|
+ const torch = this.torch;
|
|
|
+ if (obj instanceof torch.fx.node.Node) {
|
|
|
+ let subgraph = this.modules.get(obj);
|
|
|
+ if (subgraph) {
|
|
|
+ if (subgraph instanceof pytorch.Graph === false) {
|
|
|
+ subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph);
|
|
|
+ this.modules.set(obj, subgraph);
|
|
|
+ }
|
|
|
+ return subgraph;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ graph(target, module, inputs) {
|
|
|
+ const graph = module.graph;
|
|
|
+ if (module.named_modules) {
|
|
|
+ const modules = module.named_modules();
|
|
|
+ for (const obj of graph.nodes) {
|
|
|
+ if (obj.op === 'get_attr') {
|
|
|
+ const submodule = modules.get(obj.target);
|
|
|
+ if (submodule && submodule.graph) {
|
|
|
+ this.modules.set(obj, submodule);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const obj of graph.nodes) {
|
|
|
+ if (obj.op === 'placeholder') {
|
|
|
+ if (inputs) {
|
|
|
+ const value = this.value(obj);
|
|
|
+ const argument = new pytorch.Argument(obj.name, [value]);
|
|
|
+ target.inputs.push(argument);
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (obj.op === 'call_function') {
|
|
|
+ if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (obj.op === 'get_attr') {
|
|
|
+ if (this.modules.has(obj) && obj.users.size === 1) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (obj.op === 'output') {
|
|
|
+ for (const output of obj.args) {
|
|
|
+ if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const value = this.value(output);
|
|
|
+ const argument = new pytorch.Argument(output.name, [value]);
|
|
|
+ target.outputs.push(argument);
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this);
|
|
|
+ target.nodes.push(node);
|
|
|
+ }
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+pytorch.Reader = class {
|
|
|
|
|
|
static async open(context) {
|
|
|
const types = [
|
|
|
- pytorch.Container.Zip,
|
|
|
- pytorch.Container.Pickle,
|
|
|
- pytorch.Container.Tar,
|
|
|
- pytorch.Container.data_pkl,
|
|
|
- pytorch.Container.torch_utils,
|
|
|
- pytorch.Container.Mobile,
|
|
|
- pytorch.Container.ModelJson,
|
|
|
- pytorch.Container.IR,
|
|
|
- pytorch.Container.Index,
|
|
|
- pytorch.Container.ExportedProgram
|
|
|
+ pytorch.Reader.Zip,
|
|
|
+ pytorch.Reader.Pickle,
|
|
|
+ pytorch.Reader.Tar,
|
|
|
+ pytorch.Reader.data_pkl,
|
|
|
+ pytorch.Reader.torch_utils,
|
|
|
+ pytorch.Reader.Mobile,
|
|
|
+ pytorch.Reader.ModelJson,
|
|
|
+ pytorch.Reader.IR,
|
|
|
+ pytorch.Reader.Index,
|
|
|
+ pytorch.Reader.ExportedProgram
|
|
|
];
|
|
|
for (const type of types) {
|
|
|
- /* eslint-disable no-await-in-loop */
|
|
|
- const container = await type.open(context);
|
|
|
- /* eslint-enable no-await-in-loop */
|
|
|
- if (container) {
|
|
|
- return container;
|
|
|
+ // eslint-disable-next-line no-await-in-loop
|
|
|
+ const reader = await type.open(context);
|
|
|
+ if (reader) {
|
|
|
+ return reader;
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1032,12 +1106,12 @@ pytorch.Container = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Tar = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Tar = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const entries = await context.peek('tar');
|
|
|
if (entries instanceof Map && entries.has('pickle')) {
|
|
|
- return new pytorch.Container.Tar(entries);
|
|
|
+ return new pytorch.Reader.Tar(entries);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
@@ -1060,13 +1134,13 @@ pytorch.Container.Tar = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Pickle = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Pickle = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const stream = context.stream;
|
|
|
const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
|
|
|
if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
|
|
|
- return new pytorch.Container.Pickle(stream);
|
|
|
+ return new pytorch.Reader.Pickle(stream);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
@@ -1090,7 +1164,7 @@ pytorch.Container.Pickle = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.data_pkl = class extends pytorch.Container {
|
|
|
+pytorch.Reader.data_pkl = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const obj = await context.peek('pkl');
|
|
|
@@ -1098,30 +1172,30 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
|
|
|
if (obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
|
|
|
const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
|
|
|
if (name.startsWith('__torch__.')) {
|
|
|
- return new pytorch.Container.data_pkl('', obj);
|
|
|
+ return new pytorch.Reader.data_pkl('', obj);
|
|
|
}
|
|
|
}
|
|
|
if (pytorch.Utility.isTensor(obj)) {
|
|
|
- return new pytorch.Container.data_pkl('tensor', obj);
|
|
|
+ return new pytorch.Reader.data_pkl('tensor', obj);
|
|
|
}
|
|
|
if (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
|
|
|
- return new pytorch.Container.data_pkl('tensor', obj);
|
|
|
+ return new pytorch.Reader.data_pkl('tensor', obj);
|
|
|
}
|
|
|
if (obj instanceof Map) {
|
|
|
const entries = Array.from(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
|
|
|
if (entries.length > 0) {
|
|
|
- return new pytorch.Container.data_pkl('tensor', obj);
|
|
|
+ return new pytorch.Reader.data_pkl('tensor', obj);
|
|
|
}
|
|
|
} else if (!Array.isArray(obj)) {
|
|
|
const entries = Object.entries(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
|
|
|
if (entries.length > 0) {
|
|
|
- return new pytorch.Container.data_pkl('tensor', obj);
|
|
|
+ return new pytorch.Reader.data_pkl('tensor', obj);
|
|
|
}
|
|
|
}
|
|
|
for (const key of ['', 'model', 'net']) {
|
|
|
const module = key === '' ? obj : obj[key];
|
|
|
if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) {
|
|
|
- return new pytorch.Container.data_pkl('module', module);
|
|
|
+ return new pytorch.Reader.data_pkl('module', module);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1139,7 +1213,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.torch_utils = class extends pytorch.Container {
|
|
|
+pytorch.Reader.torch_utils = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const stream = context.stream;
|
|
|
@@ -1150,7 +1224,7 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
|
|
|
if (content.indexOf('torch_utils') !== -1) {
|
|
|
const obj = await context.peek('pkl');
|
|
|
if (obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'))) {
|
|
|
- return new pytorch.Container.torch_utils(obj);
|
|
|
+ return new pytorch.Reader.torch_utils(obj);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1171,12 +1245,12 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Mobile = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Mobile = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const reader = await context.peek('flatbuffers.binary');
|
|
|
if (reader && reader.identifier === 'PTMF') {
|
|
|
- return new pytorch.Container.Mobile(context);
|
|
|
+ return new pytorch.Reader.Mobile(context);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
@@ -1203,7 +1277,7 @@ pytorch.Container.Mobile = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Zip = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Zip = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const entries = await context.peek('zip');
|
|
|
@@ -1220,10 +1294,10 @@ pytorch.Container.Zip = class extends pytorch.Container {
|
|
|
return null;
|
|
|
}
|
|
|
if (records.has('data.pkl')) {
|
|
|
- return new pytorch.Container.Zip(entries);
|
|
|
+ return new pytorch.Reader.Zip(entries);
|
|
|
}
|
|
|
if (records.has('.data/version') && !records.has('archive_format')) {
|
|
|
- return new pytorch.Container.Package(entries);
|
|
|
+ return new pytorch.Reader.Package(entries);
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1267,7 +1341,7 @@ pytorch.Container.Zip = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.ModelJson = class extends pytorch.Container {
|
|
|
+pytorch.Reader.ModelJson = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const identifier = context.identifier;
|
|
|
@@ -1276,7 +1350,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
|
|
|
if (model && model.mainModule) {
|
|
|
const entries = new Map();
|
|
|
entries.set('model.json', context.stream);
|
|
|
- return new pytorch.Container.ModelJson(context, entries, model);
|
|
|
+ return new pytorch.Reader.ModelJson(context, entries, model);
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1336,14 +1410,14 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.IR = class extends pytorch.Container {
|
|
|
+pytorch.Reader.IR = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const reader = await context.read('text', 0x100);
|
|
|
if (reader && reader.length > 0) {
|
|
|
const line = reader.read('\n');
|
|
|
if (line.startsWith('graph(')) {
|
|
|
- return new pytorch.Container.IR(context);
|
|
|
+ return new pytorch.Reader.IR(context);
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1367,14 +1441,14 @@ pytorch.Container.IR = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Index = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Index = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const obj = await context.peek('json');
|
|
|
if (obj && obj.weight_map) {
|
|
|
const entries = Object.entries(obj.weight_map);
|
|
|
if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) {
|
|
|
- return new pytorch.Container.Index(context, entries);
|
|
|
+ return new pytorch.Reader.Index(context, entries);
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1424,18 +1498,18 @@ pytorch.Container.Index = class extends pytorch.Container {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.ExportedProgram = class extends pytorch.Container {
|
|
|
+pytorch.Reader.ExportedProgram = class extends pytorch.Reader {
|
|
|
|
|
|
static async open(context) {
|
|
|
const program = await context.peek('json');
|
|
|
if (program && program.schema_version && program.graph_module) {
|
|
|
- return new pytorch.Container.ExportedProgram(context, program);
|
|
|
+ return new pytorch.Reader.ExportedProgram(context, program);
|
|
|
}
|
|
|
if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) {
|
|
|
const buffer = context.stream.peek();
|
|
|
const archive_format = String.fromCharCode.apply(null, buffer);
|
|
|
if (archive_format === 'pt2') {
|
|
|
- return new pytorch.Container.ExportedProgram(context, null, context);
|
|
|
+ return new pytorch.Reader.ExportedProgram(context, null, context);
|
|
|
}
|
|
|
}
|
|
|
return null;
|
|
|
@@ -1530,9 +1604,8 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
|
|
|
const torch = this.execution.__import__('torch');
|
|
|
for (const exported_program of exported_programs.values()) {
|
|
|
if (exported_program.graph_module.graph.constants) {
|
|
|
- /* eslint-disable no-await-in-loop */
|
|
|
+ // eslint-disable-next-line no-await-in-loop
|
|
|
const zip = await import('./zip.js');
|
|
|
- /* eslint-enable no-await-in-loop */
|
|
|
const constants = exported_program.graph_module.graph.constants;
|
|
|
for (const key of Object.keys(constants)) {
|
|
|
const value = constants[key];
|
|
|
@@ -1591,9 +1664,8 @@ pytorch.Execution = class extends python.Execution {
|
|
|
constructor(sources, metadata) {
|
|
|
super(sources);
|
|
|
this._metadata = metadata;
|
|
|
- /* eslint-disable consistent-this */
|
|
|
+ // eslint-disable-next-line consistent-this
|
|
|
const execution = this;
|
|
|
- /* eslint-enable consistent-this */
|
|
|
const torch = this.torch;
|
|
|
this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => {
|
|
|
const cu = new torch.jit.CompilationUnit();
|
|
|
@@ -1797,7 +1869,7 @@ pytorch.Execution = class extends python.Execution {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-pytorch.Container.Package = class extends pytorch.Container {
|
|
|
+pytorch.Reader.Package = class extends pytorch.Reader {
|
|
|
|
|
|
constructor(entries) {
|
|
|
super();
|