Lutz Roeder 1 éve
szülő
commit
78be4a9536
11 módosított fájl, 145 hozzáadás és 141 törlés
  1. 5 8
      source/coreml.js
  2. 31 28
      source/darknet.js
  3. 1 1
      source/dlc.js
  4. 12 5
      source/mlir.js
  5. 26 28
      source/ncnn.js
  6. 3 6
      source/nnabla.js
  7. 17 19
      source/nnef.js
  8. 5 8
      source/onnx.js
  9. 15 18
      source/text.js
  10. 14 16
      source/tnn.js
  11. 16 4
      source/view.js

+ 5 - 8
source/coreml.js

@@ -1,6 +1,5 @@
 
 import * as base from './base.js';
-import * as text from './text.js';
 
 const coreml = {};
 
@@ -53,13 +52,11 @@ coreml.ModelFactory = class {
         }
         if (identifier === 'model.mil') {
             try {
-                const reader = text.Reader.open(context.stream, 2048);
-                const signature = reader.read();
-                if (signature !== undefined) {
-                    if (signature.trim().startsWith('program')) {
-                        context.type = 'coreml.mil';
-                        return;
-                    }
+                const reader = context.read('text', 2048);
+                const signature = reader.read('\n');
+                if (signature && signature.trim().startsWith('program')) {
+                    context.type = 'coreml.mil';
+                    return;
                 }
             } catch {
                 // continue regardless of error

+ 31 - 28
source/darknet.js

@@ -1,6 +1,4 @@
 
-import * as text from './text.js';
-
 const darknet = {};
 
 darknet.ModelFactory = class {
@@ -16,19 +14,21 @@ darknet.ModelFactory = class {
             }
             return;
         }
-        try {
-            const reader = text.Reader.open(context.stream, 65536);
-            for (let line = reader.read(); line !== undefined; line = reader.read()) {
-                const content = line.trim();
-                if (content.length > 0 && !content.startsWith('#')) {
-                    if (content.startsWith('[') && content.endsWith(']')) {
-                        context.type = 'darknet.model';
-                    }
-                    return;
+        const reader = context.read('text', 65536);
+        if (reader) {
+            try {
+                for (let line = reader.read('\n'); line !== undefined; line = reader.read('\n')) {
+                    const content = line.trim();
+                    if (content.length > 0 && !content.startsWith('#')) {
+                        if (content.startsWith('[') && content.endsWith(']')) {
+                            context.type = 'darknet.model';
+                        }
+                        return;
+                    }
                 }
+            } catch {
+                // continue regardless of error
             }
-        } catch {
-            // continue regardless of error
         }
     }
 
@@ -43,19 +43,22 @@ darknet.ModelFactory = class {
                 const weights = context.target;
                 const name = `${basename}.cfg`;
                 const content = await context.fetch(name);
-                const reader = new darknet.Reader(content.stream, content.identifier);
-                return new darknet.Model(metadata, reader, weights);
+                const reader = content.read('text');
+                const configuration = new darknet.Configuration(reader, content.identifier);
+                return new darknet.Model(metadata, configuration, weights);
             }
             case 'darknet.model': {
                 try {
                     const name = `${basename}.weights`;
                     const content = await context.fetch(name);
                     const weights = darknet.Weights.open(content);
-                    const reader = new darknet.Reader(context.stream, context.identifier);
-                    return new darknet.Model(metadata, reader, weights);
+                    const reader = context.read('text');
+                    const configuration = new darknet.Configuration(reader, context.identifier);
+                    return new darknet.Model(metadata, configuration, weights);
                 } catch {
-                    const reader = new darknet.Reader(context.stream, context.identifier);
-                    return new darknet.Model(metadata, reader, null);
+                    const reader = context.read('text');
+                    const configuration = new darknet.Configuration(reader, context.identifier);
+                    return new darknet.Model(metadata, configuration, null);
                 }
             }
             default: {
@@ -67,20 +70,20 @@ darknet.ModelFactory = class {
 
 darknet.Model = class {
 
-    constructor(metadata, reader, weights) {
+    constructor(metadata, configuration, weights) {
         this.format = 'Darknet';
-        this.graphs = [new darknet.Graph(metadata, reader, weights)];
+        this.graphs = [new darknet.Graph(metadata, configuration, weights)];
     }
 };
 
 darknet.Graph = class {
 
-    constructor(metadata, reader, weights) {
+    constructor(metadata, configuration, weights) {
         this.inputs = [];
         this.outputs = [];
         this.nodes = [];
         const params = {};
-        const sections = reader.read();
+        const sections = configuration.read();
         const globals = new Map();
         const net = sections.shift();
         const option_find_int = (options, key, defaultValue) => {
@@ -872,10 +875,10 @@ darknet.TensorShape = class {
     }
 };
 
-darknet.Reader = class {
+darknet.Configuration = class {
 
-    constructor(stream, identifier) {
-        this.stream = stream;
+    constructor(reader, identifier) {
+        this.reader = reader;
         this.identifier = identifier;
     }
 
@@ -883,10 +886,10 @@ darknet.Reader = class {
         // read_cfg
         const sections = [];
         let section = null;
-        const reader = text.Reader.open(this.stream);
+        const reader = this.reader;
         let lineNumber = 0;
         const setup = /^setup.*\.cfg$/.test(this.identifier);
-        for (let content = reader.read(); content !== undefined; content = reader.read()) {
+        for (let content = reader.read('\n'); content !== undefined; content = reader.read('\n')) {
             lineNumber++;
             const line = content.replace(/\s/g, '');
             if (line.length > 0) {

+ 1 - 1
source/dlc.js

@@ -350,7 +350,7 @@ dlc.Container = class {
             delete this._metadata;
             const reader = text.Reader.open(stream);
             for (;;) {
-                const line = reader.read();
+                const line = reader.read('\n');
                 if (line === undefined) {
                     break;
                 }

+ 12 - 5
source/mlir.js

@@ -2,19 +2,26 @@
 // Experimental
 // contributor @tucan9389
 
-import * as text from './text.js';
-
 const mlir = {};
 
 mlir.ModelFactory = class {
 
     match(context) {
-        context.type = 'mlir';
+        try {
+            const reader = context.read('text', 0x10000);
+            for (let line = reader.read('\n'); line !== undefined; line = reader.read('\n')) {
+                if (/module\s+(\w+\s+)?{/.test(line) || /tensor<\w+>/.test(line)) {
+                    context.type = 'mlir';
+                    return;
+                }
+            }
+        } catch {
+            // continue regardless of error
+        }
     }
 
     async open(context) {
-        const stream = context.stream;
-        const decoder = text.Decoder.open(stream);
+        const decoder = context.read('text.decoder');
         const parser = new mlir.Parser(decoder);
         const obj = parser.read();
         return new mlir.Model(obj);

+ 26 - 28
source/ncnn.js

@@ -1,6 +1,5 @@
 
 import * as base from './base.js';
-import * as text from './text.js';
 
 const ncnn = {};
 
@@ -22,21 +21,23 @@ ncnn.ModelFactory = class {
                 }
             }
         } else if (identifier.endsWith('.param') || identifier.endsWith('.cfg.ncnn')) {
-            try {
-                const reader = text.Reader.open(context.stream, 2048);
-                const signature = reader.read();
-                if (signature !== undefined) {
-                    if (signature.trim() === '7767517') {
-                        context.type = 'ncnn.model';
-                        return;
-                    }
-                    const header = signature.trim().split(' ');
-                    if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) {
-                        context.type = 'ncnn.model';
+            const reader = context.read('text', 0x10000);
+            if (reader) {
+                try {
+                    const signature = reader.read('\n');
+                    if (signature !== undefined) {
+                        if (signature.trim() === '7767517') {
+                            context.type = 'ncnn.model';
+                            return;
+                        }
+                        const header = signature.trim().split(' ');
+                        if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) {
+                            context.type = 'ncnn.model';
+                        }
                     }
+                } catch {
+                    // continue regardless of error
                 }
-            } catch {
-                // continue regardless of error
             }
         } else if (identifier.endsWith('.bin') || identifier.endsWith('.weights.ncnn')) {
             const stream = context.stream;
@@ -80,8 +81,8 @@ ncnn.ModelFactory = class {
             const reader = new ncnn.BinaryParamReader(param);
             return new ncnn.Model(metadata, reader, bin);
         };
-        const openText = (param, bin) => {
-            const reader = new ncnn.TextParamReader(param);
+        const openText = (reader, bin) => {
+            reader = new ncnn.TextParamReader(reader);
             return new ncnn.Model(metadata, reader, bin);
         };
         const identifier = context.identifier.toLowerCase();
@@ -93,12 +94,13 @@ ncnn.ModelFactory = class {
                 } else if (identifier.endsWith('.cfg.ncnn')) {
                     bin = `${context.identifier.substring(0, context.identifier.length - 9)}.weights.ncnn`;
                 }
+                const reader = context.read('text');
                 try {
                     const content = await context.fetch(bin);
                     const buffer = content.stream.peek();
-                    return openText(context.stream.peek(), buffer);
+                    return openText(reader, buffer);
                 } catch {
-                    return openText(context.stream.peek(), null);
+                    return openText(reader, null);
                 }
             }
             case 'ncnn.model.bin': {
@@ -120,8 +122,8 @@ ncnn.ModelFactory = class {
                 }
                 try {
                     const content = await context.fetch(file);
-                    const buffer = content.stream.peek();
-                    return openText(buffer, context.stream.peek());
+                    const reader = content.read('text');
+                    return openText(reader, context.stream.peek());
                 } catch {
                     const content = await context.fetch(`${file}.bin`);
                     const buffer = content.stream.peek();
@@ -634,15 +636,11 @@ ncnn.Utility = class {
 
 ncnn.TextParamReader = class {
 
-    constructor(buffer) {
-        const reader = text.Reader.open(buffer);
+    constructor(reader) {
         const lines = [];
-        for (;;) {
-            const line = reader.read();
-            if (line === undefined) {
-                break;
-            }
-            lines.push(line.trim());
+        for (let line = reader.read('\n'); line !== undefined; line = reader.read('\n')) {
+            line = line.trim();
+            lines.push(line);
         }
         const signature = lines.shift();
         const header = (signature === '7767517' ? lines.shift() : signature).split(' ');

+ 3 - 6
source/nnabla.js

@@ -1,6 +1,4 @@
 
-import * as text from './text.js';
-
 const nnabla = {};
 
 nnabla.ModelFactory = class {
@@ -29,10 +27,9 @@ nnabla.ModelFactory = class {
                 let version = '';
                 if (contexts.has('nnp_version.txt')) {
                     const context = contexts.get('nnp_version.txt');
-                    const stream = context.stream;
-                    const reader = text.Reader.open(stream);
-                    version = reader.read();
-                    version = version.split('\r').shift();
+                    const reader = context.read('text');
+                    const line = reader.read('\n');
+                    version = line.split('\r').shift();
                 }
                 if (contexts.has('parameter.protobuf')) {
                     const context = contexts.get('parameter.protobuf');

+ 17 - 19
source/nnef.js

@@ -1,6 +1,4 @@
 
-import * as text from './text.js';
-
 const nnef = {};
 
 nnef.ModelFactory = class {
@@ -8,21 +6,26 @@ nnef.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        const stream = context.stream;
         switch (extension) {
-            case 'nnef':
-                if (nnef.TextReader.open(stream)) {
+            case 'nnef': {
+                const reader = nnef.TextReader.open(context);
+                if (reader) {
                     context.type = 'nnef.graph';
+                    context.target = reader;
                 }
                 break;
-            case 'dat':
+            }
+            case 'dat': {
+                const stream = context.stream;
                 if (stream && stream.length > 2) {
                     const buffer = stream.peek(2);
                     if (buffer[0] === 0x4E && buffer[1] === 0xEF) {
                         context.type = 'nnef.dat';
+                        context.target = stream;
                     }
                 }
                 break;
+            }
             default:
                 break;
         }
@@ -35,8 +38,7 @@ nnef.ModelFactory = class {
     async open(context) {
         switch (context.type) {
             case 'nnef.graph': {
-                const stream = context.stream;
-                const reader = nnef.TextReader.open(stream);
+                const reader = context.target;
                 throw new nnef.Error(`NNEF v${reader.version} support not implemented.`);
             }
             case 'nnef.dat': {
@@ -51,13 +53,13 @@ nnef.ModelFactory = class {
 
 nnef.TextReader = class {
 
-    static open(stream) {
-        const reader = text.Reader.open(stream);
+    static open(context) {
+        const reader = context.read('text', 65536);
         for (let i = 0; i < 32; i++) {
-            const line = reader.read();
+            const line = reader.read('\n');
             const match = /version\s*(\d+\.\d+);/.exec(line);
             if (match) {
-                return new nnef.TextReader(stream, match[1]);
+                return new nnef.TextReader(context, match[1]);
             }
             if (line === undefined) {
                 break;
@@ -66,13 +68,9 @@ nnef.TextReader = class {
         return null;
     }
 
-    constructor(stream, version) {
-        this._stream = stream;
-        this._version = version;
-    }
-
-    get version() {
-        return this._version;
+    constructor(context, version) {
+        this.context = context;
+        this.version = version;
     }
 };
 

+ 5 - 8
source/onnx.js

@@ -1,6 +1,5 @@
 
 import * as protobuf from './protobuf.js';
-import * as text from './text.js';
 
 const onnx = {};
 
@@ -2010,14 +2009,13 @@ onnx.TextReader = class {
     static open(context) {
         try {
             const stream = context.stream;
-            if (stream && stream.length > 0) {
-                const size = Math.min(0x10000, stream.length);
-                const buffer = stream.peek(size);
+            if (stream && stream.length > 2) {
+                const buffer = stream.peek(2);
                 if (buffer[0] < 0x80 || buffer[0] >= 0xFE) {
-                    const reader = text.Reader.open(buffer);
+                    const reader = context.read('text', 0x10000);
                     const lines = [];
                     for (let i = 0; i < 32; i++) {
-                        const line = reader.read();
+                        const line = reader.read('\n');
                         if (line === undefined) {
                             break;
                         }
@@ -2047,8 +2045,7 @@ onnx.TextReader = class {
         onnx.proto = await this._context.require('./onnx-proto');
         onnx.proto = onnx.proto.onnx;
         try {
-            const stream = this._context.stream;
-            this._decoder = text.Decoder.open(stream);
+            this._decoder = this._context.read('text.decoder');
             this._position = 0;
             this._char = this._decoder.decode();
             this.model = this._parseModel();

+ 15 - 18
source/text.js

@@ -193,8 +193,8 @@ text.Decoder.Utf16LE = class {
                 return String.fromCharCode(c);
             }
             if (c >= 0xD800 && c < 0xDBFF) {
-                if (this._position + 1 < this._length) {
-                    const c2 = this._buffer[this._position++] | (this._buffer[this._position++] << 8);
+                if (this.position + 1 < this.length) {
+                    const c2 = this.buffer[this.position++] | (this.buffer[this.position++] << 8);
                     if (c >= 0xDC00 || c < 0xDFFF) {
                         return String.fromCodePoint(0x10000 + ((c & 0x3ff) << 10) + (c2 & 0x3ff));
                     }
@@ -225,8 +225,8 @@ text.Decoder.Utf16BE = class {
                 return String.fromCharCode(c);
             }
             if (c >= 0xD800 && c < 0xDBFF) {
-                if (this._position + 1 < this._length) {
-                    const c2 = (this._buffer[this._position++] << 8) | this._buffer[this._position++];
+                if (this.position + 1 < this.length) {
+                    const c2 = (this.buffer[this.position++] << 8) | this.buffer[this.position++];
                     if (c >= 0xDC00 || c < 0xDFFF) {
                         return String.fromCodePoint(0x10000 + ((c & 0x3ff) << 10) + (c2 & 0x3ff));
                     }
@@ -288,37 +288,34 @@ text.Decoder.Utf32BE = class {
 
 text.Reader = class {
 
-    constructor(data, length) {
-        this._decoder = text.Decoder.open(data);
-        this._position = 0;
-        this._length = length || Number.MAX_SAFE_INTEGER;
+    constructor(data) {
+        this.decoder = text.Decoder.open(data);
+        this.position = 0;
+        this.length = Number.MAX_SAFE_INTEGER;
     }
 
     static open(data, length) {
         return new text.Reader(data, length);
     }
 
-    read() {
-        if (this._position >= this._length) {
+    read(terminal) {
+        if (this.position >= this.length) {
             return undefined;
         }
         let line = '';
         let buffer = null;
         for (;;) {
-            const c = this._decoder.decode();
+            const c = this.decoder.decode();
             if (c === undefined) {
-                this._length = this._position;
-                break;
-            }
-            this._position++;
-            if (this._position > this._length) {
+                this.length = this.position;
                 break;
             }
-            if (c === '\n') {
+            this.position++;
+            if (c === terminal || this.position > this.length) {
                 break;
             }
             line += c;
-            if (line.length >= 32) {
+            if (line.length >= 64) {
                 buffer = buffer || [];
                 buffer.push(line);
                 line = '';

+ 14 - 16
source/tnn.js

@@ -1,6 +1,4 @@
 
-import * as text from './text.js';
-
 const tnn = {};
 
 tnn.ModelFactory = class {
@@ -10,9 +8,8 @@ tnn.ModelFactory = class {
         const stream = context.stream;
         if (stream && identifier.endsWith('.tnnproto')) {
             try {
-                const buffer = stream.peek();
-                const reader = text.Reader.open(buffer, 2048);
-                const content = reader.read();
+                const reader = context.read('text', 0x10000);
+                const content = reader.read('\n');
                 if (content !== undefined) {
                     const line = content.trim();
                     if (line.startsWith('"') && line.endsWith('"')) {
@@ -42,17 +39,19 @@ tnn.ModelFactory = class {
         switch (context.type) {
             case 'tnn.model': {
                 const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnmodel`;
+                const reader = context.read('text');
                 try {
                     const content = await context.fetch(name);
-                    return new tnn.Model(metadata, context, content);
+                    return new tnn.Model(metadata, reader, content);
                 } catch {
-                    return new tnn.Model(metadata, context, null);
+                    return new tnn.Model(metadata, reader, null);
                 }
             }
             case 'tnn.params': {
                 const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnproto`;
                 const content = await context.fetch(name, null);
-                return new tnn.Model(metadata, content, context);
+                const reader = content.read('text');
+                return new tnn.Model(metadata, reader, context);
             }
             default: {
                 throw new tnn.Error(`Unsupported TNN format '${context.type}'.`);
@@ -81,8 +80,8 @@ tnn.Graph = class {
         if (tnnmodel) {
             resources.read(tnnmodel);
         }
-        const reader = new tnn.TextProtoReader(tnnproto.stream);
-        reader.read();
+        const reader = new tnn.TextProtoReader(tnnproto);
+        reader.read('\n');
         const values = new Map();
         values.map = (name, type, tensor) => {
             if (name.length === 0) {
@@ -381,19 +380,18 @@ tnn.TensorShape = class {
 
 tnn.TextProtoReader = class {
 
-    constructor(stream) {
-        this.stream = stream;
+    constructor(reader) {
+        this.reader = reader;
         this.inputs = [];
         this.outputs = [];
         this.layers = [];
     }
 
     read() {
-        if (this.stream) {
-            const reader = text.Reader.open(this.stream);
+        if (this.reader) {
             let lines = [];
             for (;;) {
-                const line = reader.read();
+                const line = this.reader.read('\n');
                 if (line === undefined) {
                     break;
                 }
@@ -469,7 +467,7 @@ tnn.TextProtoReader = class {
                     this.layers.push(layer);
                 }
             }
-            delete this.stream;
+            delete this.reader;
         }
     }
 };

+ 16 - 4
source/view.js

@@ -3,6 +3,7 @@ import * as base from './base.js';
 import * as zip from './zip.js';
 import * as tar from './tar.js';
 import * as json from './json.js';
+import * as text from './text.js';
 import * as xml from './xml.js';
 import * as protobuf from './protobuf.js';
 import * as flatbuffers from './flatbuffers.js';
@@ -5382,7 +5383,7 @@ view.Context = class {
         return this._content.get(type);
     }
 
-    read(type) {
+    read(type, ...args) {
         if (!this._content.has(type)) {
             switch (type) {
                 case 'json': {
@@ -5419,11 +5420,22 @@ view.Context = class {
                 case 'protobuf.text': {
                     return protobuf.TextReader.open(this._stream);
                 }
+                case 'binary.big-endian': {
+                    return base.BinaryReader.open(this._stream, false);
+                }
                 case 'binary': {
                     return base.BinaryReader.open(this._stream);
                 }
-                case 'binary.big-endian': {
-                    return base.BinaryReader.open(this._stream, false);
+                case 'text': {
+                    if (typeof args[0] === 'number') {
+                        const length = Math.min(this._stream.length, args[0]);
+                        const buffer = this._stream.peek(length);
+                        text.Reader.open(buffer);
+                    }
+                    return text.Reader.open(this._stream);
+                }
+                case 'text.decoder': {
+                    return text.Decoder.open(this._stream);
                 }
                 default: {
                     break;
@@ -5599,7 +5611,7 @@ view.ModelFactoryService = class {
         this.register('./hickle', ['.h5', '.hkl']);
         this.register('./nnef', ['.nnef', '.dat']);
         this.register('./onednn', ['.json']);
-        this.register('./mlir', ['.mlir']);
+        this.register('./mlir', ['.mlir', '.mlir.txt']);
         this.register('./sentencepiece', ['.model']);
         this.register('./hailo', ['.hn', '.har', '.metadata.json']);
         this.register('./nnc', ['.nnc']);