Ver Fonte

Update torch.export test file (#1211)

Lutz Roeder há 4 meses atrás
pai
commit
8f99ce4e86
3 ficheiros alterados com 203 adições e 41 exclusões
  1. 145 13
      source/python.js
  2. 57 28
      source/pytorch.js
  3. 1 0
      test/models.json

+ 145 - 13
source/python.js

@@ -7595,16 +7595,26 @@ python.Execution = class {
         });
         // GreaterThan: constructor forwards '>=' as the operator argument to its base class.
         this.registerType('sympy.core.relational.GreaterThan', class extends sympy.core.relational._Greater {
             constructor(lhs, rhs) {
-                super(lhs, rhs, '>');
+                super(lhs, rhs, '>=');
             }
         });
         // _Less: subclass of _Inequality that adds no members of its own.
         this.registerType('sympy.core.relational._Less', class extends sympy.core.relational._Inequality {
         });
         // LessThan: constructor forwards '<=' as the operator argument to Relational.
         this.registerType('sympy.core.relational.LessThan', class extends sympy.core.relational.Relational {
+            constructor(lhs, rhs) {
+                super(lhs, rhs, '<=');
+            }
+        });
         // StrictLessThan: constructor forwards '<' as the operator argument to Relational.
+        this.registerType('sympy.core.relational.StrictLessThan', class extends sympy.core.relational.Relational {
             constructor(lhs, rhs) {
                 super(lhs, rhs, '<');
             }
         });
         // StrictGreaterThan: constructor forwards '>' as the operator argument to Relational.
+        this.registerType('sympy.core.relational.StrictGreaterThan', class extends sympy.core.relational.Relational {
+            constructor(lhs, rhs) {
+                super(lhs, rhs, '>');
+            }
+        });
         this.registerType('sympy.core.relational.Equality', class extends sympy.core.relational.Relational {
             constructor(lhs, rhs) {
                 super(lhs, rhs, '==');
@@ -7632,7 +7642,9 @@ python.Execution = class {
                         case 'Max': return new sympy.functions.elementary.miscellaneous.Max(...node.args.map((arg) => sympify(arg)));
                         case 'Integer': return new sympy.core.numbers.Integer(node.args[0].value);
                         case 'GreaterThan': return new sympy.core.relational.GreaterThan(sympify(node.args[0]), sympify(node.args[1]));
+                        case 'StrictGreaterThan': return new sympy.core.relational.StrictGreaterThan(sympify(node.args[0]), sympify(node.args[1]));
                         case 'LessThan': return new sympy.core.relational.LessThan(sympify(node.args[0]), sympify(node.args[1]));
+                        case 'StrictLessThan': return new sympy.core.relational.StrictLessThan(sympify(node.args[0]), sympify(node.args[1]));
                         case 'Equality': return new sympy.core.relational.Equality(sympify(node.args[0]), sympify(node.args[1]));
                         default: throw new python.Error(`Unsupported SymPy function '${node.func.id}'.`);
                     }
@@ -7652,15 +7664,22 @@ python.Execution = class {
                     if (node.op instanceof ast.Pow) {
                         return new sympy.core.power.Pow(sympify(node.left), sympify(node.right));
                     }
+                    throw new python.Error(`Unsupported SymPy BinOp op '${node.op.__class__.__name__}'.`);
                 }
                 if (node instanceof ast.Compare) {
                     const left = sympify(node.left);
                     const right = sympify(node.comparators[0]);
                     const [op] = node.ops;
                     if (op instanceof ast.Gt) {
+                        return new sympy.core.relational.StrictGreaterThan(left, right);
+                    }
+                    if (op instanceof ast.GtE) {
                         return new sympy.core.relational.GreaterThan(left, right);
                     }
                     if (op instanceof ast.Lt) {
+                        return new sympy.core.relational.StrictLessThan(left, right);
+                    }
+                    if (op instanceof ast.LtE) {
                         return new sympy.core.relational.LessThan(left, right);
                     }
                     if (op instanceof ast.Eq) {
@@ -18575,7 +18594,12 @@ python.Execution = class {
             COMPLEXFLOAT: 10,
             COMPLEXDOUBLE: 11,
             BOOL: 12,
-            BFLOAT16: 13
+            BFLOAT16: 13,
+            UINT16: 28,
+            FLOAT8E4M3FN: 29,
+            FLOAT8E5M2: 30,
+            FLOAT8E4M3FNUZ: 31,
+            FLOAT8E5M2FNUZ: 32,
         };
         torch._export.serde.schema.Layout = {
             Unknown: 0,
@@ -18875,6 +18899,111 @@ python.Execution = class {
                 }
             }
         });
+        this.registerFunction('torch.export.pt2_archive._package._load_state_dict', (f, model_name) => {
+            const legacy_file = `data/weights/${model_name}.pt`;
+            if (f.has(legacy_file)) {
+                return f.get(legacy_file);
+            }
+            const weights_config_file = `data/weights/${model_name}_weights_config.json`;
+            if (!f.has(weights_config_file)) {
+                return null;
+            }
+            const weights_config = f.get(weights_config_file);
+            const state_dict_file_map = torch.export.pt2_archive._package._build_file_map(f, weights_config, 'data/weights/');
+            const state_dict = new builtins.dict();
+            for (const [weight_fqn, payload_meta] of Object.entries(weights_config.config)) {
+                if (payload_meta.use_pickle) {
+                    const weight_bytes = f.get(`data/weights/${payload_meta.path_name}`);
+                    const weight_tensor = torch.load(weight_bytes);
+                    state_dict.set(weight_fqn, weight_tensor);
+                } else {
+                    const tensor_meta = payload_meta.tensor_meta;
+                    const tensor = state_dict_file_map.get(payload_meta.path_name);
+                    const sizes = tensor_meta.sizes.map((s) => s.as_int);
+                    const strides = tensor_meta.strides.map((s) => s.as_int);
+                    const storage_offset = tensor_meta.storage_offset.as_int;
+                    const weight_tensor = new torch.Tensor();
+                    weight_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
+                    weight_tensor.requires_grad = tensor_meta.requires_grad || false;
+                    if (payload_meta.is_param) {
+                        state_dict.set(weight_fqn, new torch.nn.parameter.Parameter(weight_tensor, tensor_meta.requires_grad));
+                    } else {
+                        state_dict.set(weight_fqn, weight_tensor);
+                    }
+                }
+            }
+            return state_dict;
+        });
+        this.registerFunction('torch.export.pt2_archive._package._load_constants', (f, model_name) => {
+            const legacy_file = `data/constants/${model_name}.pt`;
+            if (f.has(legacy_file)) {
+                const entries = f.get(legacy_file);
+                return new builtins.dict(entries);
+            }
+            const constants_config_file = `data/constants/${model_name}_constants_config.json`;
+            if (!f.has(constants_config_file)) {
+                return null;
+            }
+            const constants_config = f.get(constants_config_file);
+            const constant_file_map = torch.export.pt2_archive._package._build_file_map(f, constants_config, 'data/constants/');
+            const constants = new builtins.dict();
+            for (const [constant_fqn, payload_meta] of Object.entries(constants_config.config)) {
+                const path_name = payload_meta.path_name;
+                if (path_name.startsWith('tensor_')) {
+                    if (payload_meta.use_pickle) {
+                        const constant_bytes = f.get(`data/constants/${payload_meta.path_name}`);
+                        const constant_tensor = torch.load(constant_bytes);
+                        constants.set(constant_fqn, constant_tensor);
+                    } else {
+                        const tensor_meta = payload_meta.tensor_meta;
+                        const tensor = constant_file_map.get(payload_meta.path_name);
+                        const sizes = tensor_meta.sizes.map((s) => s.as_int);
+                        const strides = tensor_meta.strides.map((s) => s.as_int);
+                        const storage_offset = tensor_meta.storage_offset.as_int;
+                        const constant_tensor = new torch.Tensor();
+                        constant_tensor.__setstate__([tensor.storage(), storage_offset, sizes, strides]);
+                        constants.set(constant_fqn, constant_tensor);
+                    }
+                } else if (payload_meta.path_name.startsWith('custom_obj_')) {
+                    const custom_obj_bytes = f.get(`data/constants/${payload_meta.path_name}`);
+                    const custom_obj = torch._C._pickle_load_obj(custom_obj_bytes);
+                    constants.set(constant_fqn, custom_obj);
+                }
+            }
+            return constants;
+        });
+        this.registerFunction('torch._export.serde.serialize.deserialize_scalar_type', (st) => {
+            if (!torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.has(st)) {
+                throw new python.Error(`Unsupported scalar type '${st}'.`);
+            }
+            return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE.get(st);
+        });
+        this.registerFunction('torch.export.pt2_archive._package._build_file_map', (archive_reader, config, base_dir) => {
+            const file_map = new builtins.dict();
+            for (const payload_meta of Object.values(config.config)) {
+                if (payload_meta.use_pickle) {
+                    continue;
+                }
+                if (file_map.has(payload_meta.path_name)) {
+                    continue;
+                }
+                const tensor_bytes = archive_reader.get(`${base_dir}${payload_meta.path_name}`);
+                const tensor = torch.export.pt2_archive._package._create_flat_tensor_from_bytes(tensor_bytes, payload_meta.tensor_meta);
+                file_map.set(payload_meta.path_name, tensor);
+            }
+            return file_map;
+        });
+        this.registerFunction('torch.export.pt2_archive._package._create_flat_tensor_from_bytes', (tensor_bytes, tensor_meta) => {
+            const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
+            const itemsize = dtype.itemsize();
+            const num_elements = tensor_bytes.length / itemsize;
+            const storage = new torch.storage.TypedStorage(num_elements, dtype);
+            storage._set_cdata(tensor_bytes);
+            const tensor = new torch.Tensor();
+            tensor.__setstate__([storage, 0, [num_elements], [1]]);
+            tensor.requires_grad = tensor_meta.requires_grad || false;
+            return tensor;
+        });
         this.registerFunction('torch.export.pt2_archive._package.load_pt2', (f, expected_opset_version) => {
             const exported_programs = new Map();
             for (const name of f.keys()) {
@@ -18882,9 +19011,9 @@ python.Execution = class {
                 if (match) {
                     const [, model_name] = match;
                     const serialized_exported_program = f.get(`models/${model_name}.json`);
-                    const serialized_state_dict = f.get(`data/weights/${model_name}.pt`);
-                    const serialized_constants = f.get(`data/constants/${model_name}.pt`);
-                    const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`);
+                    const serialized_state_dict = torch.export.pt2_archive._package._load_state_dict(f, model_name);
+                    const serialized_constants = torch.export.pt2_archive._package._load_constants(f, model_name);
+                    const serialized_example_inputs = f.get(`data/sample_inputs/${model_name}.pt`, 'zip');
                     const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants, serialized_example_inputs);
                     const exported_program = torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
                     exported_programs.set(model_name, exported_program);
@@ -18942,7 +19071,10 @@ python.Execution = class {
             }
         });
         this.registerFunction('torch._export.serde.serialize.deserialize_torch_artifact', (serialized) => {
-            if (!serialized) {
+            if (serialized instanceof builtins.dict || serialized instanceof builtins.tuple) {
+                return serialized;
+            }
+            if (serialized === null || serialized.length === 0) {
                 return new builtins.dict();
             }
             const artifact = torch.load(serialized);
@@ -19217,8 +19349,8 @@ python.Execution = class {
                         this.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
                     */
                 this.example_inputs = null;
-                if (example_inputs && example_inputs.length > 0) {
-                    torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
+                if (example_inputs) {
+                    this.example_inputs = torch._export.serde.serialize.deserialize_torch_artifact(example_inputs);
                 }
                 this.deserialize_graph(serialized_graph_module.graph);
                 const module_call_graph = null; // this.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
@@ -19265,7 +19397,7 @@ python.Execution = class {
                 } else if (typ_ === 'as_tensor') {
                     return this.serialized_name_to_node.get(inp.as_tensor.name);
                 } else if (typ_ === 'as_scalar_type') {
-                    return torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type];
+                    return torch._export.serde.serialize.deserialize_scalar_type(inp.as_scalar_type);
                 } else if (typ_ === 'as_memory_format') {
                     return torch._export.serde.serialize._SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format];
                 } else if (typ_ === 'as_layout') {
@@ -19499,7 +19631,7 @@ python.Execution = class {
                     const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));
                     const strides = tensor_meta.strides.map((val) => this.deserialize_sym_int(val));
                     const device = this.deserialize_device(tensor_meta.device);
-                    const dtype = torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE[tensor_meta.dtype];
+                    const dtype = torch._export.serde.serialize.deserialize_scalar_type(tensor_meta.dtype);
                     return torch.empty_strided(sizes, strides, dtype, null, device);
                 } finally {
                     this.fake_tensor_mode.__exit__(null, null, null);
@@ -20350,13 +20482,13 @@ python.Execution = class {
         torch.uint16 = new torch.dtype(27, 'uint16', 2);
         torch.uint32 = new torch.dtype(28, 'uint32', 4);
         torch.uint64 = new torch.dtype(29, 'uint64', 8);
-        torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = Object.fromEntries([
+        torch._export.serde.serialize._SERIALIZE_TO_TORCH_DTYPE = new Map([
             ['uint8', 'BYTE'],
             ['int8', 'CHAR'], ['int16', 'SHORT'], ['int32', 'INT'], ['int64', 'LONG'],
             ['float16', 'HALF'], ['float32', 'FLOAT'], ['float64', 'DOUBLE'],
             ['complex32', 'COMPLEXHALF'], ['complex64', 'COMPLEXFLOAT'], ['complex128', 'COMPLEXDOUBLE'],
-            ['bool', 'BOOL'],
-            ['bfloat16', 'BFLOAT16']
+            ['bool', 'BOOL'], ['bfloat16', 'BFLOAT16'], ['uint16', 'UINT16'],
+            ['float8_e4m3fn','FLOAT8E4M3FN'], ['float8_e5m2','FLOAT8E5M2'], ['float8_e4m3fnuz','FLOAT8E4M3FNUZ'], ['float8_e5m2fnuz','FLOAT8E5M2FNUZ']
         ].map(([key, value]) => [torch._export.serde.schema.ScalarType[value], torch[key]]));
         torch.contiguous_format = new torch.memory_format('contiguous_format');
         torch.channels_last = new torch.memory_format('channels_last');

+ 57 - 28
source/pytorch.js

@@ -1436,26 +1436,49 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
                     const [, model_name] = match;
                     /* eslint-disable no-await-in-loop */
                     const model = await this.context.fetch(`models/${model_name}.json`);
-                    const constants = await this._fetch(`data/constants/${model_name}.pt`);
-                    const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`);
-                    const weights = await this._fetch(`data/weights/${model_name}.pt`);
                     const exported_program = await model.read('json');
-                    /* eslint-enable no-await-in-loop */
                     exported_programs.set(model_name, exported_program);
                     f.set(`models/${model_name}.json`, exported_program);
-                    f.set(`data/weights/${model_name}.pt`, weights);
-                    f.set(`data/constants/${model_name}.pt`, constants);
+                    const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`, 'zip');
                     f.set(`data/sample_inputs/${model_name}.pt`, sample_inputs);
+                    const weights_config = await this._fetch(`data/weights/${model_name}_weights_config.json`, 'json');
+                    if (weights_config) {
+                        f.set(`data/weights/${model_name}_weights_config.json`, weights_config);
+                        for (const payload_meta of Object.values(weights_config.config)) {
+                            const weight_data = await this._fetch(`data/weights/${payload_meta.path_name}`, 'binary');
+                            if (weight_data) {
+                                f.set(`data/weights/${payload_meta.path_name}`, weight_data);
+                            }
+                        }
+                    } else {
+                        const weights = await this._fetch(`data/weights/${model_name}.pt`, 'zip');
+                        f.set(`data/weights/${model_name}.pt`, weights);
+                    }
+                    const constants_config = await this._fetch(`data/constants/${model_name}_constants_config.json`, 'json');
+                    if (constants_config) {
+                        f.set(`data/constants/${model_name}_constants_config.json`, constants_config);
+                        for (const payload_meta of Object.values(constants_config.config)) {
+                            // eslint-enable no-await-in-loop
+                            const constant_data = await this._fetch(`data/constants/${payload_meta.path_name}`, 'binary');
+                            if (constant_data) {
+                                f.set(`data/constants/${payload_meta.path_name}`, constant_data);
+                            }
+                        }
+                    } else {
+                        const constants = await this._fetch(`data/constants/${model_name}.pt`);
+                        f.set(`data/constants/${model_name}.pt`, constants);
+                    }
+                    /* eslint-enable no-await-in-loop */
                 }
             }
-            const byteorder = await this._text('byteorder') || 'little';
+            const byteorder = await this._fetch('byteorder', 'text') || 'little';
             f.set('byteorder', byteorder);
         } else {
-            this.version = await this._text('version');
+            this.version = await this._fetch('version', 'text') || '';
             this.version = this.version.split('\n').shift().trim();
-            const weights = await this._fetch('serialized_state_dict.pt') || await this._fetch('serialized_state_dict.json');
-            const constants = await this._fetch('serialized_constants.pt') || await this._fetch('serialized_constants.json');
-            const sample_inputs = await this._fetch('serialized_example_inputs.pt');
+            const weights = await this._fetch('serialized_state_dict.pt', 'zip') || await this._fetch('serialized_state_dict.json', 'zip');
+            const constants = await this._fetch('serialized_constants.pt', 'zip') || await this._fetch('serialized_constants.json', 'zip');
+            const sample_inputs = await this._fetch('serialized_example_inputs.pt', 'zip');
             f.set('models/model.json', this.exported_program);
             f.set('data/weights/model.pt', weights);
             f.set('data/constants/model.pt', constants);
@@ -1505,31 +1528,37 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
         this.modules = pt2_contents.exported_programs;
     }
 
-    async _fetch(name) {
+    async _fetch(name, type) {
         try {
             const context = await this.context.fetch(name);
             if (context) {
-                return await context.peek('zip');
-            }
-        } catch {
-            // continue regardless of error
-        }
-        return null;
-    }
-
-    async _text(name) {
-        try {
-            const content = await this.context.fetch(name);
-            if (content) {
-                const reader = await content.read('text');
-                if (reader) {
-                    return reader.read();
+                switch (type) {
+                    case 'zip':
+                        return await context.peek('zip');
+                    case 'json':
+                        return await context.read('json');
+                    case 'text': {
+                        const reader = await context.read('text');
+                        if (reader) {
+                            return reader.read();
+                        }
+                        break;
+                    }
+                    case 'binary': {
+                        if (context && context.stream) {
+                            return context.stream.peek();
+                        }
+                        break;
+                    }
+                    default: {
+                        throw new pytorch.Error(`Unsupported context type '${type}.`);
+                    }
                 }
             }
         } catch {
             // continue regardless of error
         }
-        return '';
+        return null;
     }
 };
 

+ 1 - 0
test/models.json

@@ -5793,6 +5793,7 @@
     "target":   "draft_export.pt2",
     "source":   "https://github.com/user-attachments/files/22877643/draft_export.pt2.zip[draft_export.pt2]",
     "format":   "PyTorch Export v8.14",
+    "assert":   "model.modules[0].nodes[0].inputs[1].value[0].initializer.values.length == 400",
     "link":     "https://github.com/lutzroeder/netron/issues/1211"
   },
   {