|
|
@@ -5594,6 +5594,12 @@ python.Execution = class {
|
|
|
}
|
|
|
throw new python.Error('Unsupported torch.add expression type.');
|
|
|
});
|
|
|
+ this.registerFunction('torch.all', (input) => {
|
|
|
+ if (Array.isArray(input) && input.length === 0) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ throw new python.Error(`Unsupported 'torch.all' expression type.`);
|
|
|
+ });
|
|
|
this.registerFunction('torch.append', (list, value) => {
|
|
|
list.push(value);
|
|
|
return value;
|
|
|
@@ -5648,9 +5654,6 @@ python.Execution = class {
|
|
|
return NaN;
|
|
|
});
|
|
|
this.registerFunction('torch.eq', (left, right) => {
|
|
|
- const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
|
|
|
- left = value(left);
|
|
|
- right = value(right);
|
|
|
if (typeof left === 'string' && typeof right === 'string') {
|
|
|
return left === right;
|
|
|
}
|
|
|
@@ -5696,9 +5699,6 @@ python.Execution = class {
|
|
|
return self.replace(regex, '');
|
|
|
});
|
|
|
this.registerFunction('torch.gt', (left, right) => {
|
|
|
- const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
|
|
|
- left = value(left);
|
|
|
- right = value(right);
|
|
|
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
|
|
|
if (!isNaN(left) && !isNaN(right)) {
|
|
|
return left > right;
|
|
|
@@ -6020,9 +6020,6 @@ python.Execution = class {
|
|
|
throw new python.Error("Unsupported 'torch.remainder' expression type.");
|
|
|
});
|
|
|
this.registerFunction('torch.ne', (left, right) => {
|
|
|
- const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
|
|
|
- left = value(left);
|
|
|
- right = value(right);
|
|
|
if (typeof left === 'boolean' && typeof right === 'boolean') {
|
|
|
return left !== right;
|
|
|
}
|
|
|
@@ -6210,15 +6207,14 @@ python.Execution = class {
|
|
|
});
|
|
|
this.registerType('torch.ClassType', class extends torch.Type {
|
|
|
constructor(qualified_name, cu, is_module) {
|
|
|
- super('ClassType');
|
|
|
- this._qualified_name = qualified_name;
|
|
|
+ super('ClassType', qualified_name);
|
|
|
this._is_module = is_module;
|
|
|
this._attributes = new Map();
|
|
|
this._methods = new Map();
|
|
|
this._staticmethods = new Map();
|
|
|
}
|
|
|
qualified_name() {
|
|
|
- return this._qualified_name;
|
|
|
+ return this.annotation_str;
|
|
|
}
|
|
|
name() {
|
|
|
return this._qualified_name.split('.').pop();
|
|
|
@@ -6349,20 +6345,39 @@ python.Execution = class {
|
|
|
}
|
|
|
});
|
|
|
this.registerType('torch.TupleType', class extends torch.Type {
|
|
|
- constructor(elements) {
|
|
|
- super('TupleType');
|
|
|
+ constructor(elements, annotation_str, schema) {
|
|
|
+ super('TupleType', annotation_str);
|
|
|
this._elements = elements;
|
|
|
+ this._schema = schema;
|
|
|
}
|
|
|
static get(elements) {
|
|
|
return new torch.TupleType(elements);
|
|
|
}
|
|
|
+ static createNamed(qualified_name, field_names, field_types /*, field_defaults */) {
|
|
|
+ const args = [];
|
|
|
+ for (let i = 0; i < field_names.length; i++) {
|
|
|
+ const arg = new torch.Argument(field_names[i], field_types[i], field_types[i]);
|
|
|
+ args.push(arg);
|
|
|
+ }
|
|
|
+ const schema = new torch.FunctionSchema(qualified_name, args);
|
|
|
+ return new torch.TupleType(field_types, qualified_name, schema);
|
|
|
+ }
|
|
|
elements() {
|
|
|
return this._elements;
|
|
|
}
|
|
|
+ schema() {
|
|
|
+ return this._schema;
|
|
|
+ }
|
|
|
str() {
|
|
|
+ if (this._schema) {
|
|
|
+ return `NamedTuple(...)`;
|
|
|
+ }
|
|
|
return `(${this.elements().map((elem) => elem.str()).join(', ')})`;
|
|
|
}
|
|
|
__str__() {
|
|
|
+ if (this.annotation_str) {
|
|
|
+ return this.annotation_str;
|
|
|
+ }
|
|
|
return `Tuple[${this.elements().map((elem) => elem.__str__()).join(', ')}]`;
|
|
|
}
|
|
|
});
|
|
|
@@ -7074,11 +7089,11 @@ python.Execution = class {
|
|
|
const index = name.indexOf('(');
|
|
|
if (index === -1) {
|
|
|
this._name = name;
|
|
|
- this._overload_name = overload_name;
|
|
|
- this._arguments = args;
|
|
|
- this._returns = returns;
|
|
|
- this._is_vararg = is_vararg;
|
|
|
- this._is_varret = is_varret;
|
|
|
+ this._overload_name = overload_name || '';
|
|
|
+ this._arguments = args || [];
|
|
|
+ this._returns = returns || [];
|
|
|
+ this._is_vararg = is_vararg || false;
|
|
|
+ this._is_varret = is_varret || false;
|
|
|
} else {
|
|
|
const value = name.substring(0, index).trim();
|
|
|
const dot = value.indexOf('.');
|
|
|
@@ -7689,22 +7704,32 @@ python.Execution = class {
|
|
|
this.register('torch.jit._script');
|
|
|
this.register('torch.jit._trace');
|
|
|
this.registerType('torch.jit.Source', class {
|
|
|
- constructor(text) {
|
|
|
- this._text = text;
|
|
|
+ constructor(text_view, filename) {
|
|
|
+ this._text_view = text_view;
|
|
|
+ this._filename = filename;
|
|
|
+ }
|
|
|
+ text_str() {
|
|
|
+ return this._text_view;
|
|
|
+ }
|
|
|
+ filename() {
|
|
|
+ return this._filename;
|
|
|
}
|
|
|
});
|
|
|
- this.registerType('torch.jit.SourceLoader', class {
|
|
|
- constructor(reader, code_prefix) {
|
|
|
- this._reader = reader;
|
|
|
- this._code_prefix = code_prefix;
|
|
|
+ this.registerType('torch.jit.QualifiedName', class {
|
|
|
+ constructor(name) {
|
|
|
+ const index = name.lastIndexOf('.');
|
|
|
+ this._qualifiedName = name;
|
|
|
+ this._prefix = index === -1 ? '' : name.substring(0, index);
|
|
|
+ this._name = index === -1 ? name : name.substring(index + 1);
|
|
|
}
|
|
|
- loadSource(qualifier) {
|
|
|
- const path = `${this._code_prefix}/${qualifier}.py`;
|
|
|
- if (this._reader.has_record(path)) {
|
|
|
- const data = this._reader.get_record(path);
|
|
|
- return new torch.jit.Source(data);
|
|
|
- }
|
|
|
- return null;
|
|
|
+ qualifiedName() {
|
|
|
+ return this._qualifiedName; // "foo.bar.baz"
|
|
|
+ }
|
|
|
+ prefix() {
|
|
|
+ return this._prefix; // "foo.bar"
|
|
|
+ }
|
|
|
+ name() {
|
|
|
+ return this._name; // "baz"
|
|
|
}
|
|
|
});
|
|
|
this.registerType('torch.jit.SourceImporter', class {
|
|
|
@@ -7713,17 +7738,103 @@ python.Execution = class {
|
|
|
this._constant_table = constant_table;
|
|
|
this._source_loader = source_loader;
|
|
|
this._version = version;
|
|
|
+ this._loaded_sources = new Set();
|
|
|
+ this._to_be_defined = new Map();
|
|
|
}
|
|
|
loadType(/* name */) {
|
|
|
//
|
|
|
}
|
|
|
resolveType(name) {
|
|
|
- return this.findNamedType(new torch.jit.QualifiedName(name));
|
|
|
+ name = new torch.jit.QualifiedName(name);
|
|
|
+ return this.findNamedType(name);
|
|
|
}
|
|
|
findNamedType(name) {
|
|
|
+ // if (auto custom_class = getCustomClass(name.qualifiedName())) {
|
|
|
+ // return custom_class;
|
|
|
+ // }
|
|
|
this.parseSourceIfNeeded(name.prefix());
|
|
|
+ const key = name.qualifiedName();
|
|
|
+ const it = this._to_be_defined.get(key);
|
|
|
+ if (it && it.type === 'class') {
|
|
|
+ this._to_be_defined.delete(key);
|
|
|
+ this.importNamedType(name.prefix(), it);
|
|
|
+ }
|
|
|
+ return this._cu.get_type(name);
|
|
|
+ }
|
|
|
+ importNamedType(qualifier, class_def) {
|
|
|
+ const qualified_name = new torch.jit.QualifiedName(`${qualifier}.${class_def.name}`);
|
|
|
+ if (class_def.bases.length === 0) {
|
|
|
+ this.importClass(qualified_name, class_def, false);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const superclass_name = class_def.bases[0].value;
|
|
|
+ if (superclass_name === 'Module') {
|
|
|
+ this.importClass(qualified_name, class_def, true);
|
|
|
+ } else if (superclass_name === 'NamedTuple') {
|
|
|
+ this.importNamedTuple(qualified_name, class_def);
|
|
|
+ } else if (superclass_name === 'Interface') {
|
|
|
+ // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=false);
|
|
|
+ } else if (superclass_name === 'ModuleInterface') {
|
|
|
+ // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=true);
|
|
|
+ } else if (superclass_name === 'Enum') {
|
|
|
+ // importEnum(qualified_name, class_def);
|
|
|
+ } else {
|
|
|
+ throw new python.Error('TorchScript does not support class inheritance.');
|
|
|
+ }
|
|
|
+ }
|
|
|
+ importClass(qualified_name, class_def, is_module) {
|
|
|
+ if (qualified_name.prefix().startsWith('__torch__.torch.classes')) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const class_type = new torch.ClassType(qualified_name.qualifiedName(), this._cu, is_module);
|
|
|
+ for (const entry of class_def.body.statements) {
|
|
|
+ if (entry.type === 'var') {
|
|
|
+ const variableType = this._cu.execution.type(entry.variableType, null);
|
|
|
+ class_type.addAttribute(entry.name, variableType);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // debugger;
|
|
|
+ this._cu.register_type(class_type);
|
|
|
}
|
|
|
- parseSourceIfNeeded(/* qualifier */) {
|
|
|
+ importNamedTuple(qualified_name, named_tuple_def) {
|
|
|
+ const field_names = [];
|
|
|
+ const field_types = [];
|
|
|
+ const field_defaults = [];
|
|
|
+ for (const statement of named_tuple_def.body.statements) {
|
|
|
+ if (statement.type !== 'var') {
|
|
|
+ throw new python.Error('Unexpected statement in NamedTuple body.');
|
|
|
+ }
|
|
|
+ field_names.push(statement.name);
|
|
|
+ field_types.push(this._cu.execution.type(statement.variableType));
|
|
|
+ }
|
|
|
+ const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults);
|
|
|
+ this._cu.register_type(tt);
|
|
|
+ }
|
|
|
+ parseSourceIfNeeded(qualifier) {
|
|
|
+ if (!qualifier || this._loaded_sources.has(qualifier)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ this._loaded_sources.add(qualifier);
|
|
|
+ const src = this._source_loader(qualifier);
|
|
|
+ if (!src) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const program = this._cu.execution.parse(src.filename(), src.text_str(), null);
|
|
|
+ for (const statement of program.body) {
|
|
|
+ switch (statement.type) {
|
|
|
+ case 'def': {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'class': {
|
|
|
+ const name = `${qualifier}.${statement.name}`;
|
|
|
+ this._to_be_defined.set(name, statement);
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
});
|
|
|
this.registerType('torch.jit.ScriptModuleDeserializer', class {
|
|
|
@@ -7734,9 +7845,11 @@ python.Execution = class {
|
|
|
this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
|
|
|
this._pickle_dir_prefix = pickle_dir_prefix || '';
|
|
|
this._tensor_dir_prefix = tensor_dir_prefix || '';
|
|
|
+ const SourceLoader = (qualifier) => {
|
|
|
+ return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
|
|
|
+ };
|
|
|
this._source_importer = new torch.jit.SourceImporter(
|
|
|
- this._compilation_unit, this._constants_table,
|
|
|
- new torch.jit.SourceLoader(this._reader, this._code_prefix), reader.version());
|
|
|
+ this._compilation_unit, this._constants_table, SourceLoader, reader.version());
|
|
|
}
|
|
|
deserialize() {
|
|
|
const execution = this._compilation_unit.execution;
|
|
|
@@ -7755,6 +7868,13 @@ python.Execution = class {
|
|
|
execution.builtins.ops = torch.ops;
|
|
|
execution.builtins.inf = torch.inf;
|
|
|
execution.builtins.CONSTANTS = {};
|
|
|
+ execution._resolver = this._source_importer;
|
|
|
+ const known_types = ['__torch__.torch.classes._nnapi.Compilation'];
|
|
|
+ for (const name of known_types) {
|
|
|
+ const type = new torch.ClassType(name, this._compilation_unit, false);
|
|
|
+ type.addMethod(new torch.FunctionSchema('init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'));
|
|
|
+ this._compilation_unit.register_type(type);
|
|
|
+ }
|
|
|
if (this._reader.has_record('model.json')) {
|
|
|
return this.LEGACY_deserialize();
|
|
|
}
|
|
|
@@ -7914,6 +8034,17 @@ python.Execution = class {
|
|
|
};
|
|
|
return unpickler.load();
|
|
|
}
|
|
|
+ qualifierToArchivePath(qualifier, export_prefix) {
|
|
|
+ return `${export_prefix}${qualifier.replace(/\./g, '/')}.py`;
|
|
|
+ }
|
|
|
+ findSourceInArchiveFromQualifier(reader, export_prefix, qualifier) {
|
|
|
+ const path = this.qualifierToArchivePath(qualifier, export_prefix);
|
|
|
+ if (!reader.has_record(path)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ const data = reader.get_record(path);
|
|
|
+ return new torch.jit.Source(data.peek(), path);
|
|
|
+ }
|
|
|
});
|
|
|
this.registerType('torch.package.PackageImporter', class {
|
|
|
constructor(reader) {
|
|
|
@@ -8215,6 +8346,9 @@ python.Execution = class {
|
|
|
this._functions = new Map();
|
|
|
this._classes = new Map();
|
|
|
}
|
|
|
+ register_type(namedType) {
|
|
|
+ this._classes.set(namedType.annotation_str, namedType);
|
|
|
+ }
|
|
|
register_function(fn) {
|
|
|
this._functions.set(fn.name, fn);
|
|
|
}
|
|
|
@@ -8228,14 +8362,11 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
get_type(name) {
|
|
|
- return this._classes.get(name);
|
|
|
+ return this._classes.get(name.qualifiedName());
|
|
|
}
|
|
|
get_class(name) {
|
|
|
return this.get_type(name);
|
|
|
}
|
|
|
- register_type(name, cls) {
|
|
|
- this._classes.set(name, cls);
|
|
|
- }
|
|
|
});
|
|
|
this.registerType('torch.jit._script.ScriptModule', class extends torch.nn.modules.module.Module {});
|
|
|
this.registerType('torch.jit._trace.TracedModule', class extends torch.jit._script.ScriptModule {});
|
|
|
@@ -8399,7 +8530,7 @@ python.Execution = class {
|
|
|
if (!cls) {
|
|
|
const name = obj_type.type_name;
|
|
|
if (name.startsWith('__torch__') || name.startsWith('torch.jit')) {
|
|
|
- cls = this._cu.get_class(name);
|
|
|
+ cls = this._cu.get_class(new torch.jit.QualifiedName(name));
|
|
|
if (!cls) {
|
|
|
const torch = this._torch;
|
|
|
cls = new torch.ClassType(name, this._cu, true);
|
|
|
@@ -10247,13 +10378,6 @@ python.Execution = class {
|
|
|
return this._builtins;
|
|
|
}
|
|
|
|
|
|
- source(file) {
|
|
|
- return this._sources.has(file) ? this._sources.get(file) : null;
|
|
|
- }
|
|
|
-
|
|
|
- debug(/* file */) {
|
|
|
- }
|
|
|
-
|
|
|
exec(code , context) {
|
|
|
const reader = new python.Parser(code, '', null);
|
|
|
const program = reader.parse();
|
|
|
@@ -10263,21 +10387,35 @@ python.Execution = class {
|
|
|
this.block(program.body, context);
|
|
|
}
|
|
|
|
|
|
- parse(file) {
|
|
|
+ debug(/* file */) {
|
|
|
+ }
|
|
|
+
|
|
|
+ source(file) {
|
|
|
+ if (this._sources.has(file)) {
|
|
|
+ return this._sources.get(file);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ read(file) {
|
|
|
const buffer = this.source(file);
|
|
|
if (buffer) {
|
|
|
const debug = this.debug(file);
|
|
|
- const code = this._utf8Decoder.decode(buffer);
|
|
|
- const parser = new python.Parser(code, file, debug);
|
|
|
- const program = parser.parse();
|
|
|
- if (!program) {
|
|
|
- throw new python.Error(`Module '${file}' parse error.`);
|
|
|
- }
|
|
|
- return program;
|
|
|
+ return this.parse(file, buffer, debug);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
+ parse(file, buffer, debug) {
|
|
|
+ const code = this._utf8Decoder.decode(buffer);
|
|
|
+ const parser = new python.Parser(code, file, debug);
|
|
|
+ const program = parser.parse();
|
|
|
+ if (!program) {
|
|
|
+ throw new python.Error(`Module '${file}' parse error.`);
|
|
|
+ }
|
|
|
+ return program;
|
|
|
+ }
|
|
|
+
|
|
|
import(name, current, level) {
|
|
|
if (level) {
|
|
|
let bits = current.split('.');
|
|
|
@@ -10303,7 +10441,7 @@ python.Execution = class {
|
|
|
const path = name.split('.').join('/');
|
|
|
module.__path__ = [path];
|
|
|
const file = `${path}.py`;
|
|
|
- const program = this.parse(file);
|
|
|
+ const program = this.read(file);
|
|
|
if (program) {
|
|
|
module.__file__ = file;
|
|
|
for (const [name, value] of Object.entries(this.builtins)) {
|