Prechádzať zdrojové kódy

Format name validation

Lutz Roeder 1 mesiac pred
rodič
commit
01c449bf58
5 zmenil súbory, kde vykonal 36 pridanie a 5 odobranie
  1. 1 1
      source/acuity.js
  2. 9 1
      source/keras.js
  3. 5 0
      source/onnx.js
  4. 11 2
      source/view.js
  5. 10 1
      test/models.json

+ 1 - 1
source/acuity.js

@@ -21,7 +21,7 @@ acuity.Model = class {
 
     constructor(metadata, model, data, quantization) {
         this.name = model.MetaData.Name;
-        this.format = `Acuity v${model.MetaData.AcuityVersion}`;
+        this.format = `Acuity${model.MetaData && model.MetaData.AcuityVersion ? ` v${model.MetaData.AcuityVersion}` : ''}`;
         this.runtime = model.MetaData.Platform;
         this.modules = [new acuity.Graph(metadata, model, data, quantization)];
     }

+ 9 - 1
source/keras.js

@@ -1455,7 +1455,15 @@ tfjs.Container = class {
             throw new tfjs.Error('File format is not TensorFlow.js layers-model.');
         }
         const modelTopology = obj.modelTopology;
-        this.format = `TensorFlow.js ${obj.format ? obj.format : `Keras${modelTopology.keras_version ? (` v${modelTopology.keras_version}`) : ''}`}`;
+        if (obj.format) {
+            this.format = `TensorFlow.js ${obj.format}`;
+        } else if (modelTopology.keras_version) {
+            const match = modelTopology.keras_version.match(/^(.+)\s+(\d.*)$/);
+            const version = match ? `${match[1]} v${match[2]}` : `v${modelTopology.keras_version}`;
+            this.format = `TensorFlow.js Keras ${version}`;
+        } else {
+            this.format = 'TensorFlow.js Keras';
+        }
         this.producer = obj.convertedBy || obj.generatedBy || '';
         this.backend = modelTopology.backend || '';
         const manifests = obj.weightsManifest;

+ 5 - 0
source/onnx.js

@@ -52,6 +52,11 @@ onnx.Model = class {
         this._modules = [];
         this._format = target.format;
         this._producer = model.producer_name && model.producer_name.length > 0 ? model.producer_name + (model.producer_version && model.producer_version.length > 0 ? ` ${model.producer_version}` : '') : null;
+        if (this._producer && this._producer.startsWith('CatBoost Git info:')) {
+            const version = this._producer.match(/Branch: tags\/v([\d.]+)/);
+            const commit = this._producer.match(/Commit: ([a-f0-9]{7})/);
+            this._producer = `CatBoost${version ? ` v${version[1]}` : ''}${commit ? `+${commit[1]}` : ''}`;
+        }
         this._domain = model.domain;
         this._version = typeof model.model_version === 'number' || typeof model.model_version === 'bigint' ? model.model_version.toString() : '';
         this._description = model.doc_string;

+ 11 - 2
source/view.js

@@ -6479,7 +6479,7 @@ view.ModelFactoryService = class {
         try {
             await this._openSignature(context);
             const content = new view.Context(context);
-            const model = await this._openContext(content);
+            let model = await this._openContext(content);
             if (!model) {
                 const check = (obj) => {
                     if (obj instanceof Error) {
@@ -6504,7 +6504,16 @@ view.ModelFactoryService = class {
                 if (!entryContext) {
                     await this._unsupported(content);
                 }
-                return this._openContext(entryContext);
+                model = await this._openContext(entryContext);
+            }
+            if (!model.format || typeof model.format !== 'string' || model.format.length === 0) {
+                throw new view.Error('Invalid model format name.');
+            }
+            if (!/^[a-zA-Z][a-zA-Z0-9-.]*( [a-zA-Z][a-zA-Z0-9-.]*)*( v\d+(\.\d+)*(b\d+)?([.+-][a-zA-Z0-9]+)?)?$/.test(model.format) || model.format.includes('undefined')) {
+                throw new view.Error(`Invalid model format name '${model.format}'.`);
+            }
+            if (model.producer && /[^\x20-\x7E\u00C0-\u00FF\u0370-\u03FF]/.test(model.producer)) {
+                throw new view.Error(`Invalid model producer name '${model.producer}'.`);
             }
             return model;
         } catch (error) {

+ 10 - 1
test/models.json

@@ -4346,6 +4346,15 @@
     "tags":     "validation",
     "link":     "https://github.com/lutzroeder/netron/issues/767"
   },
+  {
+    "type":     "onnx",
+    "target":   "catboost.onnx",
+    "source":   "https://github.com/user-attachments/files/24994029/catboost.onnx.zip[catboost.onnx]",
+    "format":   "ONNX v3",
+    "producer": "CatBoost v1.2.8+0bcf252",
+    "tags":     "validation",
+    "link":     "https://github.com/lutzroeder/netron/issues/6"
+  },
   {
     "type":     "onnx",
     "target":   "Clara_DenseNet_dynamo.onnx.zip",
@@ -8306,7 +8315,7 @@
     "type":     "tfjs",
     "target":   "iamgeai2/metadata.json,iamgeai2/model.json,iamgeai2/weights.bin",
     "source":   "https://github.com/lutzroeder/netron/files/7800435/iamgeai2.zip[metadata.json,model.json,weights.bin] ",
-    "format":   "TensorFlow.js Keras vtfjs-layers 1.3.1",
+    "format":   "TensorFlow.js Keras tfjs-layers v1.3.1",
     "link":     "https://github.com/lutzroeder/netron/issues/294"
   },
   {