|
|
@@ -7595,16 +7595,26 @@ python.Execution = class {
|
|
|
});
|
|
|
this.registerType('sympy.core.relational.GreaterThan', class extends sympy.core.relational._Greater {
|
|
|
constructor(lhs, rhs) {
|
|
|
- super(lhs, rhs, '>');
|
|
|
+ super(lhs, rhs, '>=');
|
|
|
}
|
|
|
});
|
|
|
this.registerType('sympy.core.relational._Less', class extends sympy.core.relational._Inequality {
|
|
|
});
|
|
|
this.registerType('sympy.core.relational.LessThan', class extends sympy.core.relational.Relational {
|
|
|
+ constructor(lhs, rhs) {
|
|
|
+ super(lhs, rhs, '<=');
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('sympy.core.relational.StrictLessThan', class extends sympy.core.relational.Relational {
|
|
|
constructor(lhs, rhs) {
|
|
|
super(lhs, rhs, '<');
|
|
|
}
|
|
|
});
|
|
|
+ this.registerType('sympy.core.relational.StrictGreaterThan', class extends sympy.core.relational.Relational {
|
|
|
+ constructor(lhs, rhs) {
|
|
|
+ super(lhs, rhs, '>');
|
|
|
+ }
|
|
|
+ });
|
|
|
this.registerType('sympy.core.relational.Equality', class extends sympy.core.relational.Relational {
|
|
|
constructor(lhs, rhs) {
|
|
|
super(lhs, rhs, '==');
|
|
|
@@ -7632,7 +7642,9 @@ python.Execution = class {
|
|
|
case 'Max': return new sympy.functions.elementary.miscellaneous.Max(...node.args.map((arg) => sympify(arg)));
|
|
|
case 'Integer': return new sympy.core.numbers.Integer(node.args[0].value);
|
|
|
case 'GreaterThan': return new sympy.core.relational.GreaterThan(sympify(node.args[0]), sympify(node.args[1]));
|
|
|
+ case 'StrictGreaterThan': return new sympy.core.relational.StrictGreaterThan(sympify(node.args[0]), sympify(node.args[1]));
|
|
|
case 'LessThan': return new sympy.core.relational.LessThan(sympify(node.args[0]), sympify(node.args[1]));
|
|
|
+ case 'StrictLessThan': return new sympy.core.relational.StrictLessThan(sympify(node.args[0]), sympify(node.args[1]));
|
|
|
case 'Equality': return new sympy.core.relational.Equality(sympify(node.args[0]), sympify(node.args[1]));
|
|
|
default: throw new python.Error(`Unsupported SymPy function '${node.func.id}'.`);
|
|
|
}
|
|
|
@@ -7652,15 +7664,22 @@ python.Execution = class {
|
|
|
if (node.op instanceof ast.Pow) {
|
|
|
return new sympy.core.power.Pow(sympify(node.left), sympify(node.right));
|
|
|
}
|
|
|
+ throw new python.Error(`Unsupported SymPy BinOp op '${node.op.__class__.__name__}'.`);
|
|
|
}
|
|
|
if (node instanceof ast.Compare) {
|
|
|
const left = sympify(node.left);
|
|
|
const right = sympify(node.comparators[0]);
|
|
|
const [op] = node.ops;
|
|
|
if (op instanceof ast.Gt) {
|
|
|
+ return new sympy.core.relational.StrictGreaterThan(left, right);
|
|
|
+ }
|
|
|
+ if (op instanceof ast.GtE) {
|
|
|
return new sympy.core.relational.GreaterThan(left, right);
|
|
|
}
|
|
|
if (op instanceof ast.Lt) {
|
|
|
+ return new sympy.core.relational.StrictLessThan(left, right);
|
|
|
+ }
|
|
|
+ if (op instanceof ast.LtE) {
|
|
|
return new sympy.core.relational.LessThan(left, right);
|
|
|
}
|
|
|
if (op instanceof ast.Eq) {
|
|
|
@@ -18575,7 +18594,12 @@ python.Execution = class {
|
|
|
COMPLEXFLOAT: 10,
|
|
|
COMPLEXDOUBLE: 11,
|
|
|
BOOL: 12,
|
|
|
- BFLOAT16: 13
|
|
|
+ BFLOAT16: 13,
|
|
|
+ UINT16: 28,
|
|
|
+ FLOAT8E4M3FN: 29,
|
|
|
+ FLOAT8E5M2: 30,
|
|
|
+ FLOAT8E4M3FNUZ: 31,
|
|
|
+ FLOAT8E5M2FNUZ: 32,
|
|
|
};
|
|
|
torch._export.serde.schema.Layout = {
|
|
|
Unknown: 0,
|
|
|
@@ -18875,6 +18899,111 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
});
|
|
|
+ this.registerFunction('torch.export.pt2_archive._package._load_state_dict', (f, model_name) => {
|
|
|
+ const legacy_file = `data/weights/${model_name}.pt`;
|
|
|
+ if (f.has(legacy_file)) {
|
|
|
+ return f.get(legacy_file);
|
|
|
+ }
|
|
|
+ const weights_config_file = `data/weights/${model_name}_weights_config.json`;
|
|
|
+ if (!f.has(weights_config_file)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ const weights_config = f.get(weights_config_file);
|
|
|
+ const state_dict_file_map = torch.export.pt2_archive._package._build_file_map(f, weights_config, 'data/weights/');
|
|
|
+ const state_dict = new builtins.dict();
|
|
|
+ for (const [weight_fqn, payload_meta] of Object.entries(weights_config.config)) {
|
|
|
+ if (payload_meta.use_pickle) {
|
|
|
+ const weight_bytes = f.get(`data/weights/${payload_meta.path_name}`);
|
|
|
+ const weight_tensor = torch.load(weight_bytes);
|
|
|
+ state_dict.set(weight_fqn, weight_tensor);
|
|
|
+ } else {
|
|
|
+ const tensor_meta = payload_meta.tensor_meta;
|
|
|
+ const tensor = state_dict_file_map.get(payload_meta.path_name);
|
|
|
+ const sizes = tensor_meta.sizes.map((s) => s.as_int);
|
|
|
+ const strides = tensor_meta.strides.map((s) => s.as_int);
|
|
|
+ const storage_offset = tensor_meta.storage_offset.as_int;
|
|
|
+ const weight_tensor = new torch.Tensor();
|
|
|
+ weight_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
|
|
|
+ weight_tensor.requires_grad = tensor_meta.requires_grad || false;
|
|
|
+ if (payload_meta.is_param) {
|
|
|
+ state_dict.set(weight_fqn, new torch.nn.parameter.Parameter(weight_tensor, tensor_meta.requires_grad));
|
|
|
+ } else {
|
|
|
+ state_dict.set(weight_fqn, weight_tensor);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return state_dict;
|
|
|
+ });
|
|
|
+ this.registerFunction('torch.export.pt2_archive._package._load_constants', (f, model_name) => {
|
|
|
+ const legacy_file = `data/constants/${model_name}.pt`;
|
|
|
+ if (f.has(legacy_file)) {
|
|
|
+ const entries = f.get(legacy_file);
|
|
|
+ return new builtins.dict(entries);
|
|
|
+ }
|
|
|
+ const constants_config_file = `data/constants/${model_name}_constants_config.json`;
|
|
|
+ if (!f.has(constants_config_file)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ const constants_config = f.get(constants_config_file);
|
|
|
+ const constant_file_map = torch.export.pt2_archive._package._build_file_map(f, constants_config, 'data/constants/');
|
|
|
+ const constants = new builtins.dict();
|
|
|
+ for (const [constant_fqn, payload_meta] of Object.entries(constants_config.config)) {
|
|
|
+ const path_name = payload_meta.path_name;
|
|
|
+ if (path_name.startsWith('tensor_')) {
|
|
|
+ if (payload_meta.use_pickle) {
|
|
|
+ const constant_bytes = f.get(`data/constants/${payload_meta.path_name}`);
|
|
|
+ const constant_tensor = torch.load(constant_bytes);
|
|
|
+ constants.set(constant_fqn, constant_tensor);
|
|
|
+ } else {
|
|
|
+ const tensor_meta = payload_meta.tensor_meta;
|
|
|
+ const tensor = constant_file_map.get(payload_meta.path_name);
|
|
|
+ const sizes = tensor_meta.sizes.map((s) => s.as_int);
|
|
|
+ const strides = tensor_meta.strides.map((s) => s.as_int);
|
|
|
+ const storage_offset = tensor_meta.storage_offset.as_int;
|
|
|
+ const constant_tensor = new torch.Tensor();
|
|
|
+ constant_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
|
|
|
+ constants.set(constant_fqn, constant_tensor);
|
|
|
+ }
|
|
|
+ } else if (payload_meta.path_name.startsWith('custom_obj_')) {
|
|
|
+ const custom_obj_bytes = f.get(`data/constants/${payload_meta.path_name}`);
|
|
|
+ const custom_obj = torch._C._pickle_load_obj(custom_obj_bytes);
|
|
|
+ constants.set(constant_fqn, custom_obj);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return constants;
|
|
|
+ });
|
|
|
+ this.registerFunction('torch._export.serde.serialize.deserialize_scalar_type', (st) => {
|
|
|
+ if (!torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.has(st)) {
|
|
|
+ throw new python.Error(`Unsupported scalar type '${st}'.`);
|
|
|
+ }
|
|
|
+ return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.get(st);
|
|
|
+ });
|
|
|
+ this.registerFunction('torch.export.pt2_archive._package._build_file_map', (archive_reader, config, base_dir) => {
|
|
|
+ const file_map = new builtins.dict();
|
|
|
+ for (const payload_meta of Object.values(config.config)) {
|
|
|
+ if (payload_meta.use_pickle) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (file_map.has(payload_meta.path_name)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const tensor_bytes = archive_reader.get(`${base_dir}${payload_meta.path_name}`);
|
|
|
+ const tensor = torch.export.pt2_archive._package._create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta);
|
|
|
+ file_map.set(payload_meta.path_name, tensor);
|
|
|
+ }
|
|
|
+ return file_map;
|
|
|
+ });
|
|
|
+ this.registerFunction('torch.export.pt2_archive._package._create_flat_tensor_from_bytes', (tensor_bytes, tensor_meta) => {
|
|
|
+ const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
|
|
|
+ const itemsize = dtype.itemsize();
|
|
|
+ const num_elements = tensor_bytes.length / itemsize;
|
|
|
+ const storage = new torch.storage.TypedStorage(num_elements, dtype);
|
|
|
+ storage._set_cdata(tensor_bytes);
|
|
|
+ const tensor = new torch.Tensor();
|
|
|
+ tensor.__setstate__([storage, 0, [num_elements], [1]]);
|
|
|
+ tensor.requires_grad = tensor_meta.requires_grad || false;
|
|
|
+ return tensor;
|
|
|
+ });
|
|
|
this.registerFunction('torch.export.pt2_archive._package.load_pt2', (f, expected_opset_version) => {
|
|
|
const exported_programs = new Map();
|
|
|
for (const name of f.keys()) {
|
|
|
@@ -18882,9 +19011,9 @@ python.Execution = class {
|
|
|
if (match) {
|
|
|
const [, model_name] = match;
|
|
|
const serialized_exported_program = f.get(`models/${model_name}.json`);
|
|
|
- const serialized_state_dict = f.get(`data/weights/${model_name}.pt`);
|
|
|
- const serialized_constants = f.get(`data/constants/${model_name}.pt`);
|
|
|
- const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`);
|
|
|
+ const serialized_state_dict = torch.export.pt2_archive._package._load_state_dict(f, model_name);
|
|
|
+ const serialized_constants = torch.export.pt2_archive._package._load_constants(f, model_name);
|
|
|
+ const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`, 'zip');
|
|
|
const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs);
|
|
|
const exported_program = torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
|
|
|
exported_programs.set(model_name, exported_program);
|
|
|
@@ -18942,7 +19071,10 @@ python.Execution = class {
|
|
|
}
|
|
|
});
|
|
|
this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => {
|
|
|
- if (!serialized) {
|
|
|
+ if (serialized instanceof builtins.dict || serialized instanceof builtins.tuple) {
|
|
|
+ return serialized;
|
|
|
+ }
|
|
|
+ if (serialized === null || serialized.length === 0) {
|
|
|
return new builtins.dict();
|
|
|
}
|
|
|
const artifact = torch.load(serialized);
|
|
|
@@ -19217,8 +19349,8 @@ python.Execution = class {
|
|
|
this.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
|
|
|
*/
|
|
|
this.example_inputs = null;
|
|
|
- if (example_inputs && example_inputs.length > 0) {
|
|
|
- torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
|
|
|
+ if (example_inputs) {
|
|
|
+ this.example_inputs = torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
|
|
|
}
|
|
|
this.deserialize_graph(serialized_graph_module.graph);
|
|
|
const module_call_graph = null; // this.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
|
|
|
@@ -19265,7 +19397,7 @@ python.Execution = class {
|
|
|
} else if (typ_ === 'as_tensor') {
|
|
|
return this.serialized_name_to_node.get(inp.as_tensor.name);
|
|
|
} else if (typ_ === 'as_scalar_type') {
|
|
|
- return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type];
|
|
|
+ return torch._export.serde.serialize.deserialize_scalar_type(inp.as_scalar_type);
|
|
|
} else if (typ_ === 'as_memory_format') {
|
|
|
return torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format];
|
|
|
} else if (typ_ === 'as_layout') {
|
|
|
@@ -19499,7 +19631,7 @@ python.Execution = class {
|
|
|
const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));
|
|
|
const strides = tensor_meta.strides.map((val) => this.deserialize_sym_int(val));
|
|
|
const device = this.deserialize_device(tensor_meta.device);
|
|
|
- const dtype = torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype];
|
|
|
+ const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
|
|
|
return torch.empty_strided(sizes, strides, dtype, null, device);
|
|
|
} finally {
|
|
|
this.fake_tensor_mode.__exit__(null, null, null);
|
|
|
@@ -20350,13 +20482,13 @@ python.Execution = class {
|
|
|
torch.uint16 = new torch.dtype(27, 'uint16', 2);
|
|
|
torch.uint32 = new torch.dtype(28, 'uint32', 4);
|
|
|
torch.uint64 = new torch.dtype(29, 'uint64', 8);
|
|
|
- torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = Object.fromEntries([
|
|
|
+ torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = new Map([
|
|
|
['uint8', 'BYTE'],
|
|
|
['int8', 'CHAR'], ['int16', 'SHORT'], ['int32', 'INT'], ['int64', 'LONG'],
|
|
|
['float16', 'HALF'], ['float32', 'FLOAT'], ['float64', 'DOUBLE'],
|
|
|
['complex32', 'COMPLEXHALF'], ['complex64', 'COMPLEXFLOAT'], ['complex128', 'COMPLEXDOUBLE'],
|
|
|
- ['bool', 'BOOL'],
|
|
|
- ['bfloat16', 'BFLOAT16']
|
|
|
+ ['bool', 'BOOL'], ['bfloat16', 'BFLOAT16'], ['uint16', 'UINT16'],
|
|
|
+ ['float8_e4m3fn','FLOAT8E4M3FN'], ['float8_e5m2','FLOAT8E5M2'], ['float8_e4m3fnuz','FLOAT8E4M3FNUZ'], ['float8_e5m2fnuz','FLOAT8E5M2FNUZ']
|
|
|
].map(([key, value]) => [torch._export.serde.schema.ScalarType[value], torch[key]]));
|
|
|
torch.contiguous_format = new torch.memory_format('contiguous_format');
|
|
|
torch.channels_last = new torch.memory_format('channels_last');
|