|
|
@@ -12,14 +12,13 @@ pytorch.ModelFactory = class {
|
|
|
return pytorch.Container.open(context);
|
|
|
}
|
|
|
|
|
|
- async open(context, match) {
|
|
|
+ async open(context, target) {
|
|
|
const metadata = await pytorch.Metadata.open(context);
|
|
|
- const container = match;
|
|
|
- container.metadata = metadata;
|
|
|
+ const container = target;
|
|
|
container.on('resolve', (_, name) => {
|
|
|
context.exception(new pytorch.Error("Unknown type name '" + name + "'."), false);
|
|
|
});
|
|
|
- await container.read();
|
|
|
+ await container.read(metadata);
|
|
|
return new pytorch.Model(metadata, container);
|
|
|
}
|
|
|
};
|
|
|
@@ -31,7 +30,8 @@ pytorch.Model = class {
|
|
|
this._producer = container.producer || '';
|
|
|
this._graphs = [];
|
|
|
for (const entry of container.modules) {
|
|
|
- this._graphs.push(new pytorch.Graph(metadata, entry[0], entry[1], container));
|
|
|
+ const graph = new pytorch.Graph(metadata, entry[0], entry[1]);
|
|
|
+ this._graphs.push(graph);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -808,17 +808,12 @@ pytorch.Container = class {
|
|
|
}
|
|
|
|
|
|
constructor() {
|
|
|
- this._metadata = null;
|
|
|
this._events = [];
|
|
|
}
|
|
|
|
|
|
async read() {
|
|
|
}
|
|
|
|
|
|
- set metadata(value) {
|
|
|
- this._metadata = value;
|
|
|
- }
|
|
|
-
|
|
|
on(event, callback) {
|
|
|
this._events.push([ event, callback ]);
|
|
|
}
|
|
|
@@ -988,10 +983,10 @@ pytorch.Container.Mobile = class extends pytorch.Container {
|
|
|
this._context = context;
|
|
|
}
|
|
|
|
|
|
- async read() {
|
|
|
+ async read(metadata) {
|
|
|
await this._context.require('./pytorch-schema');
|
|
|
this._modules = new Map();
|
|
|
- const execution = new pytorch.jit.Execution(null, this._metadata);
|
|
|
+ const execution = new pytorch.jit.Execution(null, metadata);
|
|
|
for (const event in this._events) {
|
|
|
execution.on(event[0], event[1]);
|
|
|
}
|
|
|
@@ -1065,8 +1060,8 @@ pytorch.Container.Zip = class extends pytorch.Container {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async read() {
|
|
|
- const execution = new pytorch.jit.Execution(null, this._metadata);
|
|
|
+ async read(metadata) {
|
|
|
+ const execution = new pytorch.jit.Execution(null, metadata);
|
|
|
for (const event in this._events) {
|
|
|
execution.on(event[0], event[1]);
|
|
|
}
|