| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import * as base from './base.js';
// Namespace object that every TensorRT class in this file is attached to.
const tensorrt = {};

/**
 * Factory that detects TensorRT content and turns a matched context into a model.
 */
tensorrt.ModelFactory = class {

    /**
     * Probes the context with each known reader, Engine first and Container second.
     *
     * @param context - Passed to each reader's static `open(context)`; its
     *   `set(type, target)` method is invoked for the first reader that matches.
     * @returns The value returned by `context.set(...)` for the first match,
     *   or `null` when no reader recognizes the content.
     */
    async match(context) {
        for (const probe of [tensorrt.Engine, tensorrt.Container]) {
            const hit = probe.open(context);
            if (!hit) {
                continue;
            }
            return context.set(hit.type, hit);
        }
        return null;
    }

    /**
     * Reads the target stored in `context.value` and wraps it in a model.
     *
     * @param context - Object whose `value` is the target produced by a reader's `open`.
     * @returns A new `tensorrt.Model` built with `null` metadata and the target.
     *   Any error raised by the target's `read()` propagates to the caller.
     */
    async open(context) {
        const { value: source } = context;
        await source.read();
        return new tensorrt.Model(null, source);
    }
};
/**
 * Model wrapper exposing the source format and a single graph module.
 */
tensorrt.Model = class {

    /**
     * @param metadata - Forwarded unchanged to the `tensorrt.Graph` constructor.
     * @param model - Reader object; its `format` property becomes this model's format.
     */
    constructor(metadata, model) {
        const { format } = model;
        const graph = new tensorrt.Graph(metadata, model);
        this.format = format;
        this.modules = [graph];
    }
};
/**
 * Graph placeholder: the constructor ignores its arguments and always
 * produces empty input, output and node lists.
 */
tensorrt.Graph = class {

    constructor(/* metadata, model */) {
        Object.assign(this, {
            inputs: [],
            outputs: [],
            nodes: []
        });
    }
};
/**
 * Detector and (stub) reader for files carrying a 'ptrt' or 'ftrt' signature.
 */
tensorrt.Engine = class {

    /**
     * Looks for the 4-byte signature 'ptrt' or 'ftrt'.
     *
     * The signature is normally checked at the start of the stream. When at
     * least 24 bytes are available and byte 3 is 0x00 and byte 4 is 0x7b, the
     * leading uint32 plus 4 is taken as the signature position instead
     * (presumably a length-prefixed preamble precedes the engine; unverified).
     *
     * @param context - Object exposing an optional `stream`.
     * @returns A `tensorrt.Engine` bound to the context and the signature
     *   position, or `null` when the signature is not found.
     */
    static open(context) {
        const { stream } = context;
        if (!stream || !(stream.length >= 4)) {
            return null;
        }
        const size = Math.min(stream.length, 24);
        let header = stream.peek(size);
        let start = 0;
        const prefixed = size >= 24 && header[3] === 0x00 && header[4] === 0x7b;
        if (prefixed) {
            start = base.BinaryReader.open(header).uint32() + 4;
            if (start + 4 < stream.length) {
                // Sample the candidate signature, then restore the stream position.
                const restore = stream.position;
                stream.seek(start);
                header = stream.peek(4);
                stream.seek(restore);
            }
        }
        const signature = String.fromCharCode(...header.slice(0, 4));
        return ['ptrt', 'ftrt'].includes(signature) ? new tensorrt.Engine(context, start) : null;
    }

    /**
     * @param context - Context later used to obtain a binary reader.
     * @param position - Byte position of the signature within the content.
     */
    constructor(context, position) {
        this.type = 'tensorrt.engine';
        this.format = 'TensorRT Engine';
        this.context = context;
        this.position = position;
    }

    /**
     * Walks the 24-byte engine header and then always throws: decoding of the
     * engine body is not implemented.
     *
     * @throws {tensorrt.Error} 'Unsupported TensorRT engine signature' for an
     *   unknown version value, otherwise 'Invalid file content'.
     */
    async read() {
        const reader = await this.context.read('binary');
        if (this.position + 24 <= reader.length) {
            reader.skip(this.position);
            const header = reader.peek(24);
            // The detection state is dropped once the header has been captured.
            delete this.context;
            delete this.position;
            reader.skip(4);
            const version = reader.uint32();
            reader.uint32();
            // The two layouts order a uint32 and a uint64 field differently
            // (the uint64 is presumably a payload size; unverified).
            if (version === 0x0000 || version === 0x002B) {
                reader.uint32();
                reader.uint64();
            } else if ([0x0057, 0x0059, 0x0060, 0x0061].includes(version)) {
                reader.uint64();
                reader.uint32();
            } else {
                const content = Array.from(header, (byte) => byte.toString(16).padStart(2, '0')).join('');
                throw new tensorrt.Error(`Unsupported TensorRT engine signature (${content.slice(8)}).`);
            }
        }
        throw new tensorrt.Error('Invalid file content. File contains undocumented TensorRT engine data.');
    }
};
/**
 * Detector and (stub) reader for the format reported as 'TensorRT FlatBuffers'.
 */
tensorrt.Container = class {

    /**
     * Structural check over at most the first 512 bytes of the stream.
     *
     * The content matches when: the leading uint64 equals the stream length,
     * the following uint32 (relative to its own position) points inside the
     * sampled bytes, and the uint32 found there points back to a uint16 holding
     * that same value (this resembles a FlatBuffers table/vtable link; unverified).
     *
     * @param context - Object exposing an optional `stream`.
     * @returns A `tensorrt.Container` wrapping the stream, or `null`.
     */
    static open(context) {
        const { stream } = context;
        if (!stream) {
            return null;
        }
        const head = stream.peek(Math.min(512, stream.length));
        if (!(head.length > 12) || head[6] !== 0x00 || head[7] !== 0x00) {
            return null;
        }
        const reader = base.BinaryReader.open(head);
        const declared = reader.uint64().toNumber();
        if (declared !== stream.length) {
            return null;
        }
        // The anchor is captured before the uint32 is consumed.
        const anchor = reader.position;
        const forward = anchor + reader.uint32();
        if (!(forward < reader.length)) {
            return null;
        }
        reader.seek(forward);
        const offset = reader.uint32();
        const backward = reader.position - offset - 4;
        if (!(backward > 0 && backward < reader.length)) {
            return null;
        }
        reader.seek(backward);
        const marker = reader.uint16();
        return offset === marker ? new tensorrt.Container(stream) : null;
    }

    /**
     * @param stream - Stream that passed the structural check in `open`.
     */
    constructor(stream) {
        this.type = 'tensorrt.container';
        this.format = 'TensorRT FlatBuffers';
        this.stream = stream;
    }

    /**
     * Releases the stream reference and always throws: decoding of this
     * format is not implemented.
     *
     * @throws {tensorrt.Error} Always.
     */
    async read() {
        delete this.stream;
        throw new tensorrt.Error('Invalid file content. File contains undocumented TensorRT data.');
    }
};
/**
 * Thin wrapper around another binary reader that adds length-prefixed string
 * decoding.
 *
 * NOTE(review): assumes the wrapped reader exposes `position`, `uint64()`,
 * `skip(n)` and a non-advancing `peek(n)`, as the readers used elsewhere in
 * this file do. Confirm against the reader implementation.
 */
tensorrt.BinaryReader = class {

    /**
     * @param reader - Underlying reader that all operations are delegated to.
     */
    constructor(reader) {
        this._reader = reader;
    }

    /** Current position of the underlying reader. */
    get position() {
        return this._reader.position;
    }

    /**
     * Advances the underlying reader.
     * Previously missing, although `string()` relied on it.
     *
     * @param offset - Number of bytes to skip.
     */
    skip(offset) {
        this._reader.skip(offset);
    }

    /** Reads a 64-bit unsigned integer from the underlying reader. */
    uint64() {
        return this._reader.uint64();
    }

    /**
     * Reads a uint64 byte length followed by that many UTF-8 bytes.
     *
     * The bytes are taken from the wrapped reader. The previous implementation
     * read from a `_buffer` field that is never assigned and therefore threw.
     *
     * @returns The decoded string; `''` when the length prefix is zero.
     */
    string() {
        const length = this.uint64().toNumber();
        if (length === 0) {
            // Avoid asking the wrapped reader for a zero-length peek.
            return '';
        }
        const data = this._reader.peek(length);
        this.skip(length);
        this._decoder = this._decoder || new TextDecoder('utf-8');
        return this._decoder.decode(data);
    }
};
/**
 * Error type thrown by the TensorRT readers in this file.
 * The `name` is a fixed human-readable label rather than a class name.
 */
tensorrt.Error = class extends Error {

    /**
     * @param message - Description of the failure, passed to `Error`.
     */
    constructor(message) {
        super(message);
        const label = 'Error loading TensorRT model.';
        this.name = label;
    }
};
- export const ModelFactory = tensorrt.ModelFactory;
|