Lutz Roeder 1 anno fa
parent
commit
755ba204bf
1 ha cambiato i file con 8 aggiunte e 2 eliminazioni
  1. 8 2
      source/python.js

+ 8 - 2
source/python.js

@@ -6314,9 +6314,15 @@ python.Execution = class {
         this.registerType('torch.backends.cudnn.rnn.Unserializable', class {});
         this.registerFunction('torch.distributed._shard.sharded_tensor.pre_load_state_dict_hook');
         this.registerFunction('torch.distributed._shard.sharded_tensor.state_dict_hook');
-        this.registerType('torch.distributed.algorithms.join._JoinConfig', class {});
         this.registerFunction('torch.distributed._sharded_tensor.state_dict_hook');
         this.registerFunction('torch.distributed._sharded_tensor.pre_load_state_dict_hook');
+        this.registerType('torch.distributed.algorithms.join._JoinConfig', class {});
+        this.registerType('torch.distributed.remote_device._remote_device', class {});
+        this.registerType('torch.distributed._shard.metadata.ShardMetadata', class {});
+        this.registerType('torch.distributed._shard.sharded_tensor.api.ShardedTensor', class {});
+        this.registerType('torch.distributed._shard.sharded_tensor.metadata.ShardedTensorMetadata', class {});
+        this.registerType('torch.distributed._shard.sharded_tensor.metadata.TensorProperties', class {});
+        this.registerType('torch.distributed._shard.sharded_tensor.shard.Shard', class {});
         this.registerType('torch.distributed._tensor.api.DTensor', class extends torch._C._TensorMeta {});
         this.registerType('torch.distributed._tensor.placement_types.DTensorSpec', class {});
         this.registerType('torch.distributed._tensor.placement_types.Shard', class {});
@@ -12982,7 +12988,7 @@ python.Execution = class {
             if (!render_errors) {
                 return torch._C.matchSchemas(schemas, loc, graph, args, kwargs, self, /*render_errors=*/true);
             }
-            throw new python.Error('No matching schema found.');
+            throw new python.Error(`No matching schema '${schemas[0].name}' found.`);
         });
         this.registerFunction('torch._C.emitBuiltinCall', (loc, graph, name, args, kwargs, self) => {
             const variants = torch._C.getAllOperatorsFor(name);