소스 검색

Add ONNX text prototype (#884)

Lutz Roeder 4 년 전
부모
커밋
c6454a144a
2개의 변경된 파일607개의 추가작업 그리고 1개의 파일을 삭제
  1. 600 1
      source/onnx.js
  2. 7 0
      test/models.json

+ 600 - 1
source/onnx.js

@@ -2,6 +2,7 @@
 var onnx = onnx || {};
 var protobuf = protobuf || require('./protobuf');
 var flatbuffers = flatbuffers || require('./flatbuffers');
+var text = text || require('./text');
 
 onnx.ModelFactory = class {
 
@@ -134,6 +135,9 @@ onnx.ModelFactory = class {
         if (context.tags('flatbuffers').get('file_identifier') === 'ORTM') {
             return 'onnx.flatbuffers';
         }
+        if (onnx.TextReader.open(stream)) {
+            return 'onnx.text';
+        }
         return undefined;
     }
 
@@ -272,6 +276,22 @@ onnx.ModelFactory = class {
                     }
                 });
             }
+            case 'onnx.text': {
+                return context.require('./onnx-proto').then(() => {
+                    try {
+                        onnx.proto = protobuf.get('onnx').onnx;
+                        const stream = context.stream;
+                        const reader = onnx.TextReader.open(stream);
+                        const model = reader.read();
+                        const format = 'ONNX Text' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
+                        return open(model, format);
+                    }
+                    catch (error) {
+                        const message = error && error.message ? error.message : error.toString();
+                        throw new onnx.Error('File format is not onnx.ModelProto (' + message.replace(/\.$/, '') + ').');
+                    }
+                });
+            }
             default: {
                 throw new onnx.Error("Unknown ONNX format '" + match + "'.");
             }
@@ -295,7 +315,7 @@ onnx.Model = class {
         if (model.opset_import && model.opset_import.length > 0) {
             for (const opset_import of model.opset_import) {
                 const domain = opset_import.domain || 'ai.onnx';
-                const version = opset_import.version ? opset_import.version.toNumber() : 0;
+                const version = opset_import.version ? typeof opset_import.version === 'number' ? opset_import.version: opset_import.version.toNumber() : 0;
                 if (!imports.has(domain) || imports.get(domain) > version) {
                     imports.set(domain, version);
                 }
@@ -1730,6 +1750,585 @@ onnx.GraphContext = class {
     }
 };
 
+onnx.TextReader = class {
+
+    static open(data) {
+        try {
+            const reader = text.Reader.open(data);
+            const lines = [];
+            for (let i = 0; i < 32; i++) {
+                const line = reader.read();
+                if (line === undefined) {
+                    break;
+                }
+                lines.push(line);
+            }
+            const content = lines.join('\n');
+            if (/^\s*<\s*ir_version\s*:/m.exec(content)) {
+                return new onnx.TextReader(data);
+            }
+        }
+        catch (err) {
+            // continue regardless of error
+        }
+        return null;
+    }
+
+    constructor(data) {
+        this._data = data;
+        this._dataTypes = new Map([
+            [ 'float', 1 ], [ 'uint8', 2 ], [ 'int8', 3 ], [ 'uint16', 4 ],
+            [ 'int16', 5 ], [ 'int32', 6 ], [ 'int64', 7 ], [ 'string', 8 ],
+            [ 'bool', 9 ], [ 'float16', 10 ], [ 'double', 11 ], [ 'uint32', 12 ],
+            [ 'uint64', 13 ], [ 'complex64', 14 ], [ 'complex128', 15 ], [ 'bfloat16', 16 ]
+        ]);
+        this._attributeTypes = new Map([
+            [ 'float', 1 ], [ 'int', 2 ], [ 'string', 3 ],
+            [ 'tensor', 4 ], [ 'graph', 5 ], [ 'sparse_tensor', 11 ], [ 'type_proto', 13 ],
+            [ 'floats', 6 ], [ 'ints', 7 ], [ 'strings', 8 ],
+            [ 'tensors', 9 ], [ 'graphs', 10 ], [ 'sparse_tensors', 12 ], [ 'type_protos', 14 ]
+        ]);
+    }
+
+    read() {
+        const decoder = text.Decoder.open(this._data);
+        this._decoder = decoder;
+        this._position = 0;
+        this._char = decoder.decode();
+        return this._model();
+    }
+
+    _seek(position) {
+        this._decoder.position = position;
+        this._char = '';
+        this._next();
+    }
+
+    _model() {
+        this._whitespace();
+        const model = new onnx.proto.ModelProto();
+        if (this._match('<')) {
+            do {
+                const keyword = this._identifier();
+                this._expect(':');
+                switch (keyword) {
+                    case 'ir_version':
+                    case 'model_version':
+                        model[keyword] = this._int64();
+                        break;
+                    case 'opset_import':
+                        model[keyword] = this._operatorSetId();
+                        break;
+                    case 'producer_name':
+                    case 'producer_version':
+                    case 'domain':
+                    case 'doc_string':
+                        model[keyword] = this._string();
+                        break;
+                    case 'metadata_props':
+                        this._expect('[');
+                        if (!this._match(']')) {
+                            do {
+                                const entry = new onnx.proto.StringStringEntryProto();
+                                entry.key = this._string();
+                                this._expect(':');
+                                entry.value = this._string();
+                                model.metadata_props.push(entry);
+                            } while (this._match(','));
+                            this._expect(']');
+                        }
+                        break;
+                    default:
+                        this._throw("Unknown keyword '" + keyword + "'.");
+                        break;
+                }
+            } while (this._match(','));
+            this._expect('>');
+        }
+        model.graph = this._graph();
+        this._whitespace();
+        while (this._char !== undefined) {
+            const func = this._function();
+            if (func) {
+                model.functions.push(func);
+            }
+            this._whitespace();
+        }
+        return model;
+    }
+
+    _graph() {
+        const graph = new onnx.proto.GraphProto();
+        graph.name = this._identifier();
+        if (this._match('(')) {
+            if (!this._match(')')) {
+                do {
+                    const valueInfo = this._valueInfo();
+                    if (this._match('=')) {
+                        this._throw('Initializer not implemented.');
+                    }
+                    graph.input.push(valueInfo);
+                }
+                while (this._match(','));
+                this._expect(')');
+            }
+        }
+        this._expect('=>');
+        graph.output = this._valueInfoList();
+        if (this._match('<')) {
+            if (!this._match('>')) {
+                do {
+                    const valueInfo = this._valueInfo();
+                    if (this._match('=')) {
+                        this._throw('Initializer not implemented.');
+                    }
+                    else {
+                        graph.value_info.push(valueInfo);
+                    }
+                }
+                while (this._match(','));
+                this._expect('>');
+            }
+        }
+        graph.node = this._nodeList();
+        return graph;
+    }
+
+    _nodeList() {
+        const list = [];
+        this._expect('{');
+        while (!this._match('}')) {
+            list.push(this._node());
+        }
+        return list;
+    }
+
+    _node() {
+        const node = new onnx.proto.NodeProto();
+        node.output = this._identifierList();
+        this._expect('=');
+        let identifier = this._identifier();
+        let domain = '';
+        while (this._match('.')) {
+            if (domain) {
+                domain += '.';
+            }
+            domain += identifier;
+            identifier = this._identifier();
+        }
+        node.domain = domain;
+        node.op_type = identifier;
+        node.attribute = this._attributeList();
+        this._expect('(');
+        node.input = this._identifierList();
+        this._expect(')');
+        if (!node.attribute || node.attribute.length === 0) {
+            node.attribute = this._attributeList();
+        }
+        return node;
+    }
+
+    _attributeList() {
+        const list = [];
+        if (this._match('<')) {
+            do {
+                list.push(this._attribute());
+            }
+            while (this._match(','));
+            this._expect('>');
+        }
+        return list;
+    }
+
+    _attribute() {
+        const attribute = new onnx.proto.AttributeProto();
+        attribute.name = this._identifier();
+        if (this._match(':')) {
+            const type = this._identifier();
+            if (!this._attributeTypes.has(type)) {
+                this._throw("Unexpected attribute type '" + type + "'.");
+            }
+            attribute.type = this._attributeTypes.get(type);
+        }
+        this._expect('=');
+        if (this._match('[')) {
+            const list = [];
+            do {
+                list.push(this._literal());
+            }
+            while (this._match(','));
+            this._expect(']');
+            if (list.every((value) => typeof value === 'string')) {
+                attribute.type = onnx.AttributeType.STRINGS;
+                attribute.strings = list;
+            }
+            else if (list.every((value) => typeof value === 'number' && Number.isInteger(value))) {
+                attribute.type = onnx.AttributeType.INTS;
+                attribute.ints = list;
+            }
+            else if (list.every((value) => typeof value === 'number')) {
+                attribute.type = onnx.AttributeType.FLOATS;
+                attribute.floats = list;
+            }
+            else {
+                this._throw("Unexpected value '" + JSON.stringify(list) + "'.");
+            }
+        }
+        else {
+            if ((this._char >= 'a' && this._char <= 'z') || (this._char >= 'A' && this._char <= 'Z') || this._char === '_') {
+                const identifier = this._identifier();
+                if (this._dataTypes.has(identifier)) {
+                    attribute.type = onnx.AttributeType.TENSOR;
+                    attribute.t = this._tensor(identifier);
+                }
+                else {
+                    attribute.type = onnx.AttributeType.GRAPH;
+                    attribute.g = this._graph();
+                }
+            }
+            else if (this._match('@')) {
+                attribute.ref_attr_name = this._identifier();
+            }
+            else {
+                const value = this._literal();
+                switch (typeof value) {
+                    case 'number':
+                        if (Number.isInteger(value)) {
+                            attribute.type = onnx.AttributeType.INT;
+                            attribute.i = value;
+                        }
+                        else {
+                            attribute.type = onnx.AttributeType.FLOAT;
+                            attribute.f = value;
+                        }
+                        break;
+                    case 'string':
+                        attribute.type = onnx.AttributeType.STRING;
+                        attribute.s = value;
+                        break;
+                    default: {
+                        this._throw("Unexpected value '" + JSON.stringify(value) + "'.");
+                    }
+                }
+            }
+        }
+        return attribute;
+    }
+
+    _valueInfoList() {
+        const list = [];
+        this._expect('(');
+        if (!this._match(')')) {
+            do {
+                list.push(this._valueInfo());
+            } while (this._match(','));
+            this._expect(')');
+        }
+        return list;
+    }
+
+    _valueInfo() {
+        const valueInfo = new onnx.proto.ValueInfoProto();
+        let identifier = this._identifier();
+        if (this._dataTypes.has(identifier)) {
+            valueInfo.type = this._type(this._dataTypes.get(identifier));
+            identifier = this._identifier();
+        }
+        valueInfo.name = identifier;
+        return valueInfo;
+    }
+
+    _type(elem_type) {
+        const type = new onnx.proto.TypeProto();
+        type.elem_type = elem_type;
+        if (this._match('[')) {
+            if (!this._match(']')) {
+                type.shape = this._shape();
+                this._expect(']');
+            }
+        }
+        return type;
+    }
+
+    _shape() {
+        const shape = new onnx.proto.TensorShapeProto();
+        do {
+            const dimension = new onnx.proto.TensorShapeProto.Dimension();
+            if (!this._match('?')) {
+                const identifier = this._identifier(true);
+                if (identifier) {
+                    dimension.dim_param = identifier;
+                }
+                else {
+                    dimension.dim_value = this._int64();
+                }
+            }
+            shape.dim.push(dimension);
+        }
+        while (this._match(','));
+        return shape;
+    }
+
+    _tensor() {
+        const tensor = new onnx.proto.TensorProto();
+        this._expect('=');
+        this._expect('{');
+        if (!this._match('}')) {
+            do {
+                this._next();
+            } while (!this._match('}'));
+        }
+        return tensor;
+    }
+
+    _function() {
+        const func = new onnx.proto.FunctionProto();
+        if (this._match('<')) {
+            do {
+                const keyword = this._identifier();
+                this._expect(':');
+                switch (keyword) {
+                    case 'opset_import':
+                        func[keyword] = this._operatorSetId();
+                        break;
+                    case 'domain':
+                    case 'doc_string':
+                        func[keyword] = this._string();
+                        break;
+                    default:
+                        this._throw("Unknown keyword '" + keyword + "'.");
+                        break;
+                }
+            }
+            while (this._match(','));
+            this._expect('>');
+        }
+        func.name = this._identifier();
+        if (this._match('<')) {
+            func.attribute = this._identifierList();
+            this._expect('>');
+        }
+        if (this._match('(')) {
+            func.input = this._identifierList();
+            this._expect(')');
+        }
+        this._expect('=>');
+        if (this._match('(')) {
+            func.output = this._identifierList();
+            this._expect(')');
+        }
+        func.node = this._nodeList();
+        return func;
+    }
+
+    _identifierList() {
+        const list = [];
+        const identifier = this._identifier(true);
+        if (identifier) {
+            list.push(identifier);
+            while (this._match(',')) {
+                list.push(this._identifier());
+            }
+        }
+        return list;
+    }
+
+    _identifier(optional) {
+        this._whitespace();
+        const value = [];
+        if ((this._char >= 'a' && this._char <= 'z') || (this._char >= 'A' && this._char <= 'Z')) {
+            value.push(this._char);
+            this._next();
+            while ((this._char >= 'a' && this._char <= 'z') || (this._char >= 'A' && this._char <= 'Z') || (this._char >= '0' && this._char <= '9') || this._char === '_') {
+                value.push(this._char);
+                this._next();
+            }
+        }
+        if (optional !== true && value.length == 0) {
+            this._throw('Identifier expected.');
+        }
+        return value.join('');
+    }
+
+    _literal() {
+        this._whitespace();
+        let decimal_point = false;
+        if (this._char === '"') {
+            const value = [];
+            this._next();
+            while (this._char !== undefined && this._char !== '"') {
+                value.push(this._char);
+                this._next();
+            }
+            if (this._char !== undefined) {
+                this._next();
+            }
+            return value.join('');
+        }
+        else if ((this._char >= '0' && this._char <= '9') || this._char === '-') {
+            const value = [ this._char ];
+            this._next();
+            while ((this._char >= '0' && this._char <= '9') || this._char === '.') {
+                if (this._char === '.') {
+                    if (decimal_point) {
+                        this._throw();
+                    }
+                    decimal_point = true;
+                }
+                value.push(this._char);
+                this._next();
+            }
+            if (value.length === 0) {
+                this._throw('Value expected.');
+            }
+            if (this._char === 'e' || this._char === 'E') {
+                decimal_point = true;
+                value.push(this._char);
+                this._next();
+                if (this._char === '+' || this._char === '-') {
+                    value.push(this._char);
+                    this._next();
+                }
+                while ((this._char >= '0' && this._char <= '9')) {
+                    value.push(this._char);
+                    this._next();
+                }
+            }
+            return decimal_point ? Number.parseFloat(value.join('')) : Number.parseInt(value.join(''), 10);
+        }
+        return undefined;
+    }
+
+    _int64() {
+        const value = this._literal();
+        if (!Number.isInteger(value)) {
+            this._throw('Integer value expected.');
+        }
+        return value;
+    }
+
+    _string() {
+        const value = this._literal();
+        if (typeof value !== 'string') {
+            this._throw('String value expected.');
+        }
+        return value;
+    }
+
+    _operatorSetId() {
+        const list = [];
+        this._expect('[');
+        if (!this._match(']')) {
+            do {
+                const value = new onnx.proto.OperatorSetIdProto();
+                value.domain = this._string();
+                this._expect(':');
+                value.version = this._int64();
+                list.push(value);
+            }
+            while (this._match(','));
+            this._expect(']');
+        }
+        return list;
+    }
+
+    _match(value) {
+        this._whitespace();
+        if (this._char !== value[0]) {
+            return false;
+        }
+        if (value.length === 1) {
+            this._next();
+            return true;
+        }
+        const position = this._position;
+        for (let i = 0; i < value.length; i++) {
+            if (this._char !== value[i]) {
+                this._seek(position);
+                return false;
+            }
+            this._next();
+        }
+        return true;
+    }
+
+    _expect(value) {
+        if (!this._match(value)) {
+            this._unexpected();
+        }
+        return true;
+    }
+
+    _whitespace() {
+        for (;;) {
+            while (this._char === ' ' || this._char === '\n' || this._char === '\r' || this._char === '\t') {
+                this._next();
+            }
+            if (this._char === undefined || this._char !== '#') {
+                break;
+            }
+            while (this._char !== undefined && this._char !== '\n') {
+                this._next();
+            }
+        }
+    }
+
+    _next() {
+        if (this._char === undefined) {
+            this._unexpected();
+        }
+        this._position = this._decoder.position;
+        this._char = this._decoder.decode();
+    }
+
+    _unexpected() {
+        let c = this._char;
+        if (c === undefined) {
+            throw new onnx.Error('Unexpected end of input.');
+        }
+        else if (c === '"') {
+            c = 'string';
+        }
+        else if ((c >= '0' && c <= '9') || c === '-') {
+            c = 'number';
+        }
+        else {
+            if (c < ' ' || c > '\x7F') {
+                const name = Object.keys(this._escape).filter((key) => this._escape[key] === c);
+                c = (name.length === 1) ? '\\' + name : '\\u' + ('000' + c.charCodeAt(0).toString(16)).slice(-4);
+            }
+            c = "token '" + c + "'";
+        }
+        this._throw('Unexpected ' + c);
+    }
+
+    _throw(message) {
+        throw new onnx.Error(message.replace(/\.$/, '') + this._location());
+    }
+
+    _location() {
+        let line = 1;
+        let column = 1;
+        this._decoder.position = 0;
+        let c;
+        do {
+            if (this._decoder.position === this._position) {
+                return ' at ' + line.toString() + ':' + column.toString() + '.';
+            }
+            c = this._decoder.decode();
+            if (c === '\n') {
+                line++;
+                column = 1;
+            }
+            else {
+                column++;
+            }
+        }
+        while (c !== undefined);
+        return ' at ' + line.toString() + ':' + column.toString() + '.';
+    }
+};
+
 onnx.Error = class extends Error {
 
     constructor(message) {

+ 7 - 0
test/models.json

@@ -3264,6 +3264,13 @@
     "format": "ONNX v4",
     "link":   "https://github.com/lutzroeder/netron/issues/532"
   },
+  {
+    "type":   "onnx",
+    "target": "example.txt",
+    "source": "https://github.com/lutzroeder/netron/files/8174755/example.txt.zip[example.txt]",
+    "format": "ONNX Text v7",
+    "link":   "https://github.com/lutzroeder/netron/issues/532"
+  },
   {
     "type":   "onnx",
     "target": "Exermote.onnx",