tensorrt.js 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. var tensorrt = {};
  2. var base = require('./base');
  tensorrt.ModelFactory = class {

      /**
       * Probes the context's stream for a TensorRT artifact.
       * A raw engine ('ptrt' magic) is tried first, then the FlatBuffers container heuristic.
       * @param {object} context - object exposing a `stream` property.
       * @returns {tensorrt.Engine|tensorrt.Container|null} wrapper for the detected format, or null.
       */
      match(context) {
          const { stream } = context;
          const engine = tensorrt.Engine.open(stream);
          if (engine) {
              return engine;
          }
          return tensorrt.Container.open(stream);
      }

      /**
       * Builds the model asynchronously from the object returned by match().
       * Construction runs inside a promise callback, so any error thrown while
       * reading the model surfaces as a rejection rather than a synchronous throw.
       * @param {object} context - unused here.
       * @param {tensorrt.Engine|tensorrt.Container} match - result of match().
       * @returns {Promise<tensorrt.Model>}
       */
      open(context, match) {
          const build = () => new tensorrt.Model(null, match);
          return Promise.resolve().then(build);
      }
  };
  tensorrt.Model = class {

      /**
       * Wraps a detected TensorRT source as a single-graph model.
       * @param {*} metadata - forwarded unchanged to tensorrt.Graph.
       * @param {object} model - source object whose `format` property names the model;
       *                         also forwarded to tensorrt.Graph.
       */
      constructor(metadata, model) {
          // Read the format first: for the visible source classes this is where header parsing happens.
          this._format = model.format;
          const graph = new tensorrt.Graph(metadata, model);
          this._graphs = [ graph ];
      }

      /** @returns {string} format name reported by the source object. */
      get format() { return this._format; }

      /** @returns {Array<tensorrt.Graph>} the single graph built in the constructor. */
      get graphs() { return this._graphs; }
  };
  tensorrt.Graph = class {

      /**
       * Creates a placeholder graph. The (metadata, model) arguments supplied by
       * tensorrt.Model are accepted but ignored, so all collections start and stay empty.
       */
      constructor(/* metadata, model */) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
      }

      /** @returns {Array} graph inputs (always empty). */
      get inputs() { return this._inputs; }

      /** @returns {Array} graph outputs (always empty). */
      get outputs() { return this._outputs; }

      /** @returns {Array} graph nodes (always empty). */
      get nodes() { return this._nodes; }
  };
  tensorrt.Engine = class {

      /**
       * Recognizes a serialized TensorRT engine by its 'ptrt' magic.
       * @param {object} stream - stream exposing `length` and `peek(count)`.
       * @returns {tensorrt.Engine|null} an engine wrapper when the magic matches and the
       *          stream holds at least the 24-byte header inspected by _read(); otherwise null.
       */
      static open(stream) {
          const signature = [ 0x70, 0x74, 0x72, 0x74 ]; // ptrt
          if (stream && stream.length >= 24 && stream.peek(4).every((value, index) => value === signature[index])) {
              return new tensorrt.Engine(stream);
          }
          return null;
      }

      constructor(stream) {
          this._stream = stream;
      }

      /**
       * Validates the engine header before reporting the format name.
       * @throws {tensorrt.Error} whenever a stream is attached (see _read()).
       */
      get format() {
          this._read();
          return 'TensorRT Engine';
      }

      /**
       * Inspects the 24-byte engine header and reports why the engine cannot be loaded.
       * Recognized header versions produce an "undocumented engine data" error; any
       * other version produces an "unsupported signature" error.
       * The header is only peeked: the payload is never read, so a corrupt or truncated
       * size field cannot replace these messages with an unrelated end-of-file error.
       * @throws {tensorrt.Error}
       */
      _read() {
          if (!this._stream) {
              return;
          }
          const buffer = this._stream.peek(24);
          // Hex dump of the header without the 4-byte 'ptrt' magic (8 hex digits), for error messages.
          const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('').substring(8);
          const reader = new base.BinaryReader(buffer);
          reader.skip(4); // 'ptrt' magic, already verified by open()
          const version = reader.uint32();
          switch (version) {
              // Field order after the magic for the recognized versions:
              //   0x0000, 0x002B:                 uint32 version, uint32, uint32, uint64 payload size
              //   0x0057, 0x0059, 0x0060, 0x0061: uint32 version, uint32, uint64 payload size, uint32
              // The payload (size bytes following the 24-byte header) is undocumented, so it is not parsed.
              case 0x0000:
              case 0x002B:
              case 0x0057:
              case 0x0059:
              case 0x0060:
              case 0x0061:
                  throw new tensorrt.Error("Invalid file content. File contains undocumented TensorRT engine data (" + content + ").");
              default:
                  throw new tensorrt.Error("Unsupported TensorRT engine signature (" + content + ").");
          }
      }
  };
  tensorrt.Container = class {

      /**
       * Heuristically recognizes a FlatBuffers-based TensorRT container from a prefix of
       * at most 512 bytes:
       *   1. the leading uint64 must equal the stream length (its two high bytes are
       *      first required to be zero as a cheap pre-filter);
       *   2. the uint32 at offset 8 is added to its own position to locate a root table;
       *   3. the uint32 at the root table is subtracted from the table position to locate
       *      a vtable, whose leading uint16 must equal that same value.
       * Every read is bounds-checked against the full width of the value read, so short or
       * malformed input yields null instead of an out-of-range exception during detection.
       * @param {object} stream - stream exposing `length` and `peek(count)`.
       * @returns {tensorrt.Container|null}
       */
      static open(stream) {
          if (!stream) {
              return null;
          }
          const buffer = stream.peek(Math.min(512, stream.length));
          if (buffer.length <= 12 || buffer[6] !== 0x00 || buffer[7] !== 0x00) {
              return null;
          }
          const reader = new base.BinaryReader(buffer);
          const length = reader.uint64();
          if (length !== stream.length) {
              return null;
          }
          const rootOffsetPosition = reader.position;
          const root = rootOffsetPosition + reader.uint32();
          // The root table must have room for the 4-byte offset read next.
          if (root + 4 > buffer.length) {
              return null;
          }
          reader.seek(root);
          const offset = reader.uint32();
          const vtable = root - offset;
          // The vtable must have room for the 2-byte size read next.
          if (vtable <= 0 || vtable + 2 > buffer.length) {
              return null;
          }
          reader.seek(vtable);
          const size = reader.uint16();
          return offset === size ? new tensorrt.Container(stream) : null;
      }

      constructor(stream) {
          this._stream = stream;
      }

      /**
       * Reports the format name after validating the content.
       * @throws {tensorrt.Error} always (see _read()).
       */
      get format() {
          this._read();
          return 'TensorRT FlatBuffers';
      }

      /**
       * The container layout is undocumented, so this always throws.
       * The message embeds a hex dump of up to the first 24 bytes, without the leading
       * 8-byte length field (16 hex digits).
       * @throws {tensorrt.Error}
       */
      _read() {
          const buffer = this._stream.peek(Math.min(24, this._stream.length));
          const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
          throw new tensorrt.Error('Invalid file content. File contains undocumented TensorRT data (' + content.substring(16) + ').');
      }
  };
  tensorrt.BinaryReader = class extends base.BinaryReader {

      /**
       * Reads a length-prefixed string: a uint64 byte count followed by that many
       * bytes, decoded as UTF-8. The decoder is created lazily and cached per reader.
       * @returns {string}
       */
      string() {
          const size = this.uint64();
          const start = this._position;
          this.skip(size);
          if (!this._decoder) {
              this._decoder = new TextDecoder('utf-8');
          }
          const bytes = this._buffer.subarray(start, this._position);
          return this._decoder.decode(bytes);
      }
  };
  tensorrt.Error = class extends Error {

      /**
       * Error raised while loading a TensorRT file.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          const title = 'Error loading TensorRT model.';
          this.name = title;
      }
  };
  // CommonJS export hook: only the factory is exported, and only when a `module`
  // object with an `exports` object is available.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: tensorrt.ModelFactory });
  }