Lutz Roeder 4 лет назад
Родитель
Сommit
9b757a4428
1 измененных файлов с 16 добавлено и 12 удалено
  1. 16 12
      source/onnx.js

+ 16 - 12
source/onnx.js

@@ -17,7 +17,7 @@ onnx.ModelFactory = class {
             });
         };
         switch (this._format(context)) {
-            case 'pbtxt':
+            case 'onnx.pbtxt.ModelProto':
                 return context.require('./onnx-proto').then(() => {
                     try {
                         onnx.proto = protobuf.get('onnx').onnx;
@@ -32,7 +32,7 @@ onnx.ModelFactory = class {
                         throw new onnx.Error('File text format is not onnx.ModelProto (' + message.replace(/\.$/, '') + ').');
                     }
                 });
-            case 'pb-tensor':
+            case 'onnx.pb.TensorProto':
                 return context.require('./onnx-proto').then(() => {
                     // TensorProto
                     // input_0.pb, output_0.pb
@@ -60,7 +60,7 @@ onnx.ModelFactory = class {
                         throw new onnx.Error('File format is not onnx.TensorProto (' + message.replace(/\.$/, '') + ').');
                     }
                 });
-            case 'pb-graph':
+            case 'onnx.pb.GraphProto':
                 return context.require('./onnx-proto').then(() => {
                     // GraphProto
                     try {
@@ -77,7 +77,7 @@ onnx.ModelFactory = class {
                         throw new onnx.Error('File format is not onnx.GraphProto (' + message.replace(/\.$/, '') + ').');
                     }
                 });
-            case 'pb-model':
+            case 'onnx.pb.ModelProto':
                 return context.require('./onnx-proto').then(() => {
                     // ModelProto
                     try {
@@ -93,7 +93,7 @@ onnx.ModelFactory = class {
                         throw new onnx.Error('File format is not onnx.ModelProto (' + message.replace(/\.$/, '') + ').');
                     }
                 });
-            case 'ort':
+            case 'onnx.flatbuffers': {
                 return context.require('./ort-schema').then((/* schema */) => {
                     try {
                         onnx.schema = flatbuffers.get('ort').onnxruntime.experimental.fbs;
@@ -143,6 +143,10 @@ onnx.ModelFactory = class {
                         throw new onnx.Error('File format is not ort.Model (' + message.replace(/\.$/, '') + ').');
                     }
                 });
+            }
+            default: {
+                throw new onnx.Error("Unknown ONNX format '" + this._format(context) + "'.");
+            }
         }
     }
 
@@ -197,7 +201,7 @@ onnx.ModelFactory = class {
                 if (tags.get(1) === 0 && tags.get(2) === 0 && tags.get(9) === 2) {
                     const schema = [[1,0],[2,0],[4,2],[5,2],[7,2],[8,2],[9,2]];
                     if (schema.every((pair) => !tags.has(pair[0]) || tags.get(pair[0]) === pair[1])) {
-                        return 'pb-tensor';
+                        return 'onnx.pb.TensorProto';
                     }
                 }
                 // GraphProto
@@ -226,7 +230,7 @@ onnx.ModelFactory = class {
                         if (nodeBuffer) {
                             const nameBuffer = decode(nodeBuffer, 4);
                             if (nameBuffer && nameBuffer.every((c) => c > 0x20 && c < 0x7f)) {
-                                return 'pb-graph';
+                                return 'onnx.pb.GraphProto';
                             }
                         }
                     }
@@ -235,7 +239,7 @@ onnx.ModelFactory = class {
                 if (tags.get(7) === 2) {
                     const schema = [[1,0],[2,2],[3,2],[4,2][5,0],[6,2],[7,2],[8,2],[14,2],[20,2]];
                     if (schema.every((pair) => !tags.has(pair[0]) || tags.get(pair[0]) === pair[1])) {
-                        return 'pb-model';
+                        return 'onnx.pb.ModelProto';
                     }
                 }
             }
@@ -258,19 +262,19 @@ onnx.ModelFactory = class {
                     'WinMLTools'
                 ];
                 if (producers.some((producer) => Array.from(producer).every((ch, index) => index + 4 < buffer.length && ch.charCodeAt(0) === buffer[index + 4]))) {
-                    return 'pb-model';
+                    return 'onnx.pb.ModelProto';
                 }
             }
         }
         tags = context.tags('pbtxt');
         if (tags.has('ir_version')) {
-            return 'pbtxt';
+            return 'onnx.pbtxt.ModelProto';
         }
         if (tags.has('graph') && extension !== 'model') {
-            return 'pbtxt';
+            return 'onnx.pbtxt.ModelProto';
         }
         if (context.tags('flatbuffers').get('file_identifier') === 'ORTM') {
-            return 'ort';
+            return 'onnx.flatbuffers';
         }
         return '';
     }