tensorrt.js 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import * as base from './base.js';
  // Namespace object that collects every TensorRT loader class defined in this module.
  const tensorrt = {};

  /**
   * Factory that recognizes TensorRT files and opens them as models.
   */
  tensorrt.ModelFactory = class {

      /**
       * Probes the context with each known TensorRT file type.
       *
       * @param {object} context - Loader context. It is handed to each probe's
       *     static `open(context)` and must expose `set(type, target)`.
       * @returns {Promise<*>} The result of `context.set(target.type, target)` for
       *     the first probe that returns a target, or `null` when none matches.
       */
      async match(context) {
          // Probe order matters: Engine is tried before Container and the first hit wins.
          const entries = [
              tensorrt.Engine,
              tensorrt.Container
          ];
          for (const entry of entries) {
              const target = entry.open(context);
              if (target) {
                  return context.set(target.type, target);
              }
          }
          return null;
      }

      /**
       * Reads the previously matched target and wraps it in a model.
       *
       * @param {object} context - Loader context whose `value` is the target to read
       *     (presumably the object stored by `match` via `context.set` - confirm
       *     against the host framework).
       * @returns {Promise<tensorrt.Model>} Model built from the target.
       * @throws Propagates any error raised by `target.read()`.
       */
      async open(context) {
          const target = context.value;
          await target.read();
          return new tensorrt.Model(null, target);
      }
  };
  /**
   * Top-level model object produced by the factory.
   */
  tensorrt.Model = class {

      /**
       * @param {*} metadata - Forwarded unchanged to the graph constructor.
       * @param {object} model - Parsed target; its `format` string is copied onto this model.
       */
      constructor(metadata, model) {
          this.format = model.format;
          // A model always exposes exactly one graph.
          this.modules = [new tensorrt.Graph(metadata, model)];
      }
  };
  /**
   * Graph placeholder: it ignores its constructor arguments and starts with
   * empty `inputs`, `outputs` and `nodes` lists.
   */
  tensorrt.Graph = class {

      constructor(/* metadata, model */) {
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
      }
  };
  /**
   * Detector and (stub) reader for files carrying a `ptrt` or `ftrt` signature.
   */
  tensorrt.Engine = class {

      /**
       * Checks whether the context's stream holds a TensorRT engine.
       *
       * @param {object} context - Loader context; only `context.stream` is used.
       * @returns {tensorrt.Engine|null} An engine bound to the context and the byte
       *     offset of the signature, or `null` when the signature is not found.
       */
      static open(context) {
          const stream = context.stream;
          if (stream && stream.length >= 4) {
              const size = Math.min(stream.length, 24);
              let buffer = stream.peek(size);
              let offset = 0;
              // Bytes 3 and 4 equal to 0x00 and 0x7b ('{') select an alternate layout:
              // the leading uint32 plus 4 gives the offset at which the signature is
              // looked up instead (presumably a length-prefixed header - TODO confirm).
              if (size >= 24 && buffer[3] === 0x00 && buffer[4] === 0x7b) {
                  const reader = base.BinaryReader.open(buffer);
                  offset = reader.uint32() + 4;
                  if ((offset + 4) < stream.length) {
                      // Peek at the candidate signature without moving the stream.
                      const position = stream.position;
                      stream.seek(offset);
                      buffer = stream.peek(4);
                      stream.seek(position);
                  }
              }
              const signature = String.fromCharCode(...buffer.slice(0, 4));
              if (signature === 'ptrt' || signature === 'ftrt') {
                  return new tensorrt.Engine(context, offset);
              }
          }
          return null;
      }

      /**
       * @param {object} context - Loader context used later by `read()`.
       * @param {number} position - Byte offset of the engine signature.
       */
      constructor(context, position) {
          this.type = 'tensorrt.engine';
          this.format = 'TensorRT Engine';
          this.context = context;
          this.position = position;
      }

      /**
       * Walks the fixed-size engine header and then rejects the file: the payload
       * format is undocumented, so this method never returns normally.
       *
       * @returns {Promise<never>}
       * @throws {tensorrt.Error} With the header bytes in hex when the version field
       *     is not one of the known values; otherwise with a generic
       *     "undocumented" message.
       */
      async read() {
          const context = this.context;
          const reader = await context.read('binary');
          const offset = this.position + 24;
          if (offset <= reader.length) {
              reader.skip(this.position);
              // Keep the 24 header bytes for the diagnostic message below.
              const buffer = reader.peek(24);
              delete this.context;
              delete this.position;
              reader.skip(4); // signature
              const version = reader.uint32();
              reader.uint32();
              switch (version) {
                  case 0x0000:
                  case 0x002B: {
                      reader.uint32();
                      reader.uint64(); // presumably the payload size - TODO confirm
                      break;
                  }
                  case 0x0057:
                  case 0x0059:
                  case 0x0060:
                  case 0x0061: {
                      reader.uint64(); // presumably the payload size - TODO confirm
                      reader.uint32();
                      break;
                  }
                  default: {
                      const content = Array.from(buffer).map((c) => c.toString(16).padStart(2, '0')).join('');
                      // Drop the first 8 hex digits (the 4 signature bytes) from the report.
                      throw new tensorrt.Error(`Unsupported TensorRT engine signature (${content.substring(8)}).`);
                  }
              }
          }
          throw new tensorrt.Error('Invalid file content. File contains undocumented TensorRT engine data.');
      }
  };
  /**
   * Detector and (stub) reader for the format this file labels 'TensorRT FlatBuffers'.
   */
  tensorrt.Container = class {

      /**
       * Heuristically checks whether the context's stream is a TensorRT container.
       *
       * The check reads at most the first 512 bytes and requires that:
       *  - the leading uint64 equals the total stream length,
       *  - the following uint32 is a relative jump that stays inside the buffer,
       *  - the uint32 at the jump target, used as a backwards distance, lands on
       *    a uint16 whose value equals that same distance.
       * (This looks like a FlatBuffers root table / vtable cross-check - TODO confirm.)
       *
       * @param {object} context - Loader context; only `context.stream` is used.
       * @returns {tensorrt.Container|null} A container bound to the stream, or
       *     `null` when any check fails.
       */
      static open(context) {
          const stream = context.stream;
          if (!stream) {
              return null;
          }
          const buffer = stream.peek(Math.min(512, stream.length));
          if (buffer.length <= 12 || buffer[6] !== 0x00 || buffer[7] !== 0x00) {
              return null;
          }
          const reader = base.BinaryReader.open(buffer);
          const length = reader.uint64().toNumber();
          if (length !== stream.length) {
              return null;
          }
          // `reader.position` is evaluated before the uint32 is consumed.
          let position = reader.position + reader.uint32();
          if (position >= reader.length) {
              return null;
          }
          reader.seek(position);
          const offset = reader.uint32();
          position = reader.position - offset - 4;
          if (position <= 0 || position >= reader.length) {
              return null;
          }
          reader.seek(position);
          const size = reader.uint16();
          return offset === size ? new tensorrt.Container(stream) : null;
      }

      /**
       * @param {object} stream - Stream holding the container bytes.
       */
      constructor(stream) {
          this.type = 'tensorrt.container';
          this.format = 'TensorRT FlatBuffers';
          this.stream = stream;
      }

      /**
       * Releases the stream reference and rejects the file: the container format
       * is undocumented, so this method never returns normally.
       *
       * @returns {Promise<never>}
       * @throws {tensorrt.Error} Always.
       */
      async read() {
          delete this.stream;
          throw new tensorrt.Error('Invalid file content. File contains undocumented TensorRT data.');
      }
  };
  /**
   * Thin wrapper around another binary reader that adds length-prefixed strings.
   */
  tensorrt.BinaryReader = class {

      /**
       * @param {object} reader - Underlying reader. This class calls its
       *     `position`, `uint64()`, `skip(n)` and `peek(n)` members.
       */
      constructor(reader) {
          this._reader = reader;
      }

      /** Current position of the underlying reader. */
      get position() {
          return this._reader.position;
      }

      /**
       * Advances the underlying reader.
       *
       * @param {number} offset - Number of bytes to skip.
       */
      skip(offset) {
          this._reader.skip(offset);
      }

      /** Reads a uint64 from the underlying reader and returns it unchanged. */
      uint64() {
          return this._reader.uint64();
      }

      /**
       * Reads a string stored as a uint64 byte length followed by UTF-8 bytes.
       *
       * The bytes are taken from the wrapped reader with `peek` and then skipped.
       * The previous code sliced `this._buffer`, which is never assigned, and
       * called a `skip` method this class did not define, so it always threw.
       * NOTE(review): assumes the wrapped reader's `peek(n)` returns the next `n`
       * bytes without advancing - confirm against base.BinaryReader.
       *
       * @returns {string} The decoded text.
       */
      string() {
          const length = this.uint64().toNumber();
          const data = this._reader.peek(length);
          this.skip(length);
          // The decoder is created lazily and reused across calls.
          this._decoder = this._decoder || new TextDecoder('utf-8');
          return this._decoder.decode(data);
      }
  };
  /**
   * Error type thrown by the TensorRT loader classes.
   */
  tensorrt.Error = class extends Error {

      /**
       * @param {string} message - Human-readable description of the failure.
       */
      constructor(message) {
          super(message);
          // `name` holds a display sentence here, not a class identifier.
          this.name = 'Error loading TensorRT model.';
      }
  };
  // Public entry point of the module: the factory class under its plain name.
  export const ModelFactory = tensorrt.ModelFactory;