Ver Fonte

Add Pickle test file (#901)

Lutz Roeder há 1 mês atrás
pai
commit
43bf2a7a7a
3 ficheiros alterados com 40 adições e 0 exclusões
  1. 1 0
      source/pickle.js
  2. 32 0
      source/python.js
  3. 7 0
      test/models.json

+ 1 - 0
source/pickle.js

@@ -34,6 +34,7 @@ pickle.ModelFactory = class {
                 ['cuml.ensemble.randomforestclassifier.RandomForestClassifier', 'cuML'],
                 ['shap.explainers._linear.LinearExplainer', 'SHAP'],
                 ['gensim.models.word2vec.Word2Vec', 'Gensim'],
+                ['ray.rllib.algorithms.ppo.ppo.PPOConfig', 'Ray RLlib'],
                 ['builtins.bytearray', 'Pickle'],
                 ['builtins.dict', 'Pickle'],
                 ['collections.OrderedDict', 'Pickle'],

+ 32 - 0
source/python.js

@@ -4689,7 +4689,39 @@ python.Execution = class {
         this.registerFunction('cloudpickle.cloudpickle_fast._function_setstate');
         const ray = this.register('ray');
         this.register('ray.cloudpickle.cloudpickle');
+        this.register('ray.cloudpickle.cloudpickle_fast');
         ray.cloudpickle.cloudpickle._builtin_type = cloudpickle.cloudpickle._builtin_type;
+        ray.cloudpickle.cloudpickle._fill_function = cloudpickle.cloudpickle._fill_function;
+        ray.cloudpickle.cloudpickle._make_cell = cloudpickle.cloudpickle._make_cell;
+        ray.cloudpickle.cloudpickle._make_function = cloudpickle.cloudpickle._make_function;
+        ray.cloudpickle.cloudpickle._make_skel_func = cloudpickle.cloudpickle._make_skel_func;
+        ray.cloudpickle.cloudpickle._make_skeleton_class = cloudpickle.cloudpickle._make_skeleton_class;
+        ray.cloudpickle.cloudpickle._make_empty_cell = cloudpickle.cloudpickle._make_empty_cell;
+        ray.cloudpickle.cloudpickle._empty_cell_value = cloudpickle.cloudpickle._empty_cell_value;
+        ray.cloudpickle.cloudpickle._class_setstate = cloudpickle.cloudpickle._class_setstate;
+        ray.cloudpickle.cloudpickle._function_setstate = cloudpickle.cloudpickle._function_setstate;
+        ray.cloudpickle.cloudpickle._lookup_class_or_track = cloudpickle.cloudpickle._lookup_class_or_track;
+        ray.cloudpickle.cloudpickle_fast._class_setstate = cloudpickle.cloudpickle._class_setstate;
+        ray.cloudpickle.cloudpickle_fast._function_setstate = cloudpickle.cloudpickle._function_setstate;
+        this.registerType('ray.rllib.algorithms.ppo.ppo.PPO', class {});
+        this.registerType('ray.rllib.algorithms.ppo.ppo.PPOConfig', class {});
+        this.registerType('ray.rllib.algorithms.algorithm_config.AlgorithmConfig', class {});
+        this.registerFunction('ray.rllib.algorithms.algorithm_config.AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN');
+        this.registerType('ray.rllib.algorithms.algorithm_config.TorchCompileWhatToCompile', class {});
+        this.registerType('ray.rllib.evaluation.collectors.simple_list_collector.SimpleListCollector', class {});
+        this.registerType('ray.rllib.callbacks.callbacks.RLlibCallback', class {});
+        this.registerType('ray.rllib.core.learner.learner.TorchCompileWhatToCompile', class {});
+        this.registerType('ray.rllib.policy.policy.PolicySpec', class {});
+        this.registerType('ray.rllib.policy.sample_batch.SampleBatch', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.mean.MeanStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.ema.EmaStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.min.MinStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.max.MaxStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.sum.SumStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.lifetime_sum.LifetimeSumStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.percentiles.PercentilesStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.item.ItemStats', class {});
+        this.registerType('ray.rllib.utils.metrics.stats.item_series.ItemSeriesStats', class {});
         this.registerType('collections.Counter', class {});
         this.registerFunction('collections.defaultdict', (/* default_factory */) => {
             return {};

+ 7 - 0
test/models.json

@@ -5404,6 +5404,13 @@
     "format":   "QNN",
     "link":     "https://github.com/lutzroeder/netron/issues/1283"
   },
+  {
+    "type":     "pickle",
+    "target":   "algorithm_state.pkl",
+    "source":   "https://github.com/user-attachments/files/24619982/algorithm_state.pkl.zip[algorithm_state.pkl]",
+    "format":   "Pickle",
+    "link":     "https://github.com/lutzroeder/netron/issues/901"
+  },
   {
     "type":     "pickle",
     "target":   "batches.meta",