rknn.js 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. /* jshint esversion: 6 */
  2. var rknn = rknn || {};
  3. var json = json || require('./json');
  rknn.ModelFactory = class {

      /**
       * Reports whether the stream starts with the RKNN container signature
       * ('RKNN' followed by four zero bytes).
       * @param {object} context - Host context; `context.stream` must expose `length` and `peek(count)`.
       * @returns {boolean} True when the leading bytes equal the signature.
       */
      match(context) {
          const magic = [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ];
          const stream = context.stream;
          // Streams shorter than the signature can never match.
          if (!(magic.length <= stream.length)) {
              return false;
          }
          const head = stream.peek(magic.length);
          if (head.every((byte, position) => byte === magic[position])) {
              return true;
          }
          return false;
      }

      /**
       * Loads the operator metadata, then parses the container and builds the model.
       * @param {object} context - Host context forwarded to `rknn.Metadata.open`; its `stream` is parsed.
       * @returns {Promise} Resolves to a new `rknn.Model`.
       */
      open(context) {
          const build = (metadata) => {
              const container = rknn.Container.open(context.stream);
              return new rknn.Model(metadata, container.model, container.weights);
          };
          return rknn.Metadata.open(context).then(build);
      }
  };
  rknn.Model = class {

      /**
       * Top-level model wrapper holding a single graph.
       * @param {object} metadata - Operator metadata, passed through to the graph.
       * @param {object} model - Parsed model description; `version`, `ori_network_platform`,
       *     `network_platform` and `target_platform` are read here.
       * @param {*} weights - Weights blob, passed through to the graph.
       */
      constructor(metadata, model, weights) {
          let runtime = '';
          if (model.target_platform) {
              runtime = model.target_platform.join(',');
          }
          this._version = model.version;
          // The original platform name wins over the converted one.
          this._producer = model.ori_network_platform || model.network_platform || '';
          this._runtime = runtime;
          this._graphs = [ new rknn.Graph(metadata, model, weights) ];
      }

      /** Format label, e.g. 'RKNN v1'. */
      get format() {
          return `RKNN v${this._version}`;
      }

      /** Name of the platform that produced the network, or ''. */
      get producer() {
          return this._producer;
      }

      /** Comma-separated target platforms, or ''. */
      get runtime() {
          return this._runtime;
      }

      /** Array holding the model's single graph. */
      get graphs() {
          return this._graphs;
      }
  };
  rknn.Graph = class {

      /**
       * Builds the graph from the parsed model description.
       *
       * Arguments are registered under these keys:
       *   'const_tensor:<tensor_id>', '<node_id>:<output_port>' and 'norm_tensor:<tensor_id>'.
       * Every entry of `model.nodes` gets fresh `input` / `output` arrays filled from
       * `model.connection` (the model object is modified in place).
       *
       * @param {object} metadata - Operator metadata handed to every node.
       * @param {object} model - Parsed model description.
       * @param {*} weights - Weights blob handed to constant tensors.
       * @throws {rknn.Error} When a graph input/output refers to an unknown argument.
       */
      constructor(metadata, model, weights) {
          this._name = model.name || '';
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          const table = new Map();
          const register = (name, type, initializer) => {
              table.set(name, new rknn.Argument(name, type, initializer));
          };
          // Constant tensors carry data from the weights blob.
          for (const entry of model.const_tensor) {
              const name = 'const_tensor:' + entry.tensor_id.toString();
              const shape = new rknn.TensorShape(entry.size);
              const type = new rknn.TensorType(entry.dtype, shape);
              register(name, type, new rknn.Tensor(type, entry.offset, weights));
          }
          // Virtual tensors are node outputs without type information.
          for (const entry of model.virtual_tensor) {
              register(entry.node_id.toString() + ':' + entry.output_port.toString(), null, null);
          }
          // Normal tensors are typed but carry no data.
          for (const entry of model.norm_tensor) {
              const name = 'norm_tensor:' + entry.tensor_id.toString();
              const shape = new rknn.TensorShape(entry.size);
              register(name, new rknn.TensorType(entry.dtype, shape), null);
          }
          for (const node of model.nodes) {
              node.input = [];
              node.output = [];
          }
          for (const link of model.connection) {
              if (link.left === 'input') {
                  model.nodes[link.node_id].input.push(link);
                  const source = link.right_node;
                  if (source) {
                      // The same connection is also the producing node's output at that port.
                      model.nodes[source.node_id].output[source.tensor_id] = link;
                  }
              }
              else if (link.left === 'output') {
                  model.nodes[link.node_id].output.push(link);
              }
          }
          for (const port of model.graph) {
              const key = port.right + ':' + port.right_tensor_id.toString();
              const argument = table.get(key);
              if (!argument) {
                  throw new rknn.Error("Invalid argument '" + key + "'.");
              }
              const suffix = port.left_tensor_id === 0 ? '' : port.left_tensor_id.toString();
              const parameter = new rknn.Parameter(port.left + suffix, [ argument ]);
              if (port.left === 'input') {
                  this._inputs.push(parameter);
              }
              else if (port.left === 'output') {
                  this._outputs.push(parameter);
              }
          }
          for (const node of model.nodes) {
              this._nodes.push(new rknn.Node(metadata, node, table));
          }
      }

      /** Graph name, or ''. */
      get name() {
          return this._name;
      }

      /** Graph input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** Graph output parameters. */
      get outputs() {
          return this._outputs;
      }

      /** Nodes in model order. */
      get nodes() {
          return this._nodes;
      }
  };
  rknn.Parameter = class {

      /**
       * Named group of arguments attached to a graph or node port.
       * @param {string} name - Parameter name.
       * @param {Array} args - Arguments bound to this parameter.
       */
      constructor(name, args) {
          this._arguments = args;
          this._name = name;
      }

      /** Parameter name as given to the constructor. */
      get name() {
          return this._name;
      }

      /** Always true in this implementation. */
      get visible() {
          return true;
      }

      /** The argument list as given to the constructor. */
      get arguments() {
          return this._arguments;
      }
  };
  rknn.Argument = class {

      /**
       * A named value flowing through the graph.
       * @param {string} name - Unique argument identifier.
       * @param {object} [type] - Tensor type; stored as null when falsy.
       * @param {object} [initializer] - Constant tensor; stored as null when falsy.
       * @throws {rknn.Error} When `name` is not a string.
       */
      constructor(name, type, initializer) {
          const valid = typeof name === 'string';
          if (!valid) {
              throw new rknn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      /** Argument identifier. */
      get name() {
          return this._name;
      }

      /** Tensor type or null. */
      get type() {
          return this._type;
      }

      /** Constant tensor or null. */
      get initializer() {
          return this._initializer;
      }
  };
  rknn.Node = class {

      /**
       * Builds one graph node and binds its connections to arguments.
       *
       * `node.input` / `node.output` default to empty arrays and are written back
       * to `node`. Connections are grouped into parameters following the operator
       * schema from `metadata.type(node.op)`; a schema entry flagged `list` takes
       * all remaining connections. Without a schema entry the parameter is named
       * 'input' / 'output' for index 0 and by its index otherwise.
       *
       * @param {object} metadata - Must provide `type(name)` returning a schema or null.
       * @param {object} node - Node description; `name`, `op`, `input`, `output` and `nn` are read.
       * @param {Map} args - Arguments keyed by '<type>:<tensor_id>' or '<node_id>:<tensor_id>'.
       * @throws {rknn.Error} When a connection has no target or refers to an unknown argument.
       */
      constructor(metadata, node, args) {
          this._metadata = metadata;
          this._name = node.name || '';
          this._type = node.op;
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          const schema = this._metadata.type(this._type);

          // Maps one connection record to its registered argument.
          const resolve = (connection, kind) => {
              let key = '';
              if (connection.right_tensor) {
                  key = connection.right_tensor.type + ':' + connection.right_tensor.tensor_id.toString();
              }
              else if (connection.right_node) {
                  key = connection.right_node.node_id.toString() + ':' + connection.right_node.tensor_id.toString();
              }
              else {
                  throw new rknn.Error('Invalid ' + kind + ' argument.');
              }
              const argument = args.get(key);
              if (!argument) {
                  throw new rknn.Error('Invalid ' + kind + " argument '" + key + "'.");
              }
              return argument;
          };

          // Groups connections into parameters according to the schema definitions.
          const group = (connections, definitions, kind) => {
              const parameters = [];
              for (let i = 0; i < connections.length; ) {
                  const definition = definitions && i < definitions.length ? definitions[i] : { name: i === 0 ? kind : i.toString() };
                  const count = definition.list ? connections.length - i : 1;
                  const list = connections.slice(i, i + count).map((connection) => resolve(connection, kind));
                  parameters.push(new rknn.Parameter(definition.name, list));
                  i += count;
              }
              return parameters;
          };

          node.input = node.input || [];
          this._inputs = group(node.input, schema && schema.inputs, 'input');
          node.output = node.output || [];
          this._outputs = group(node.output, schema && schema.outputs, 'output');

          // Every entry of every parameter group under `nn` becomes one attribute.
          if (node.nn) {
              const nn = node.nn;
              for (const key of Object.keys(nn)) {
                  const params = nn[key];
                  for (const name of Object.keys(params)) {
                      this._attributes.push(new rknn.Attribute(name, params[name]));
                  }
              }
          }
      }

      /** Node name, or ''. */
      get name() {
          return this._name;
      }

      /**
       * Operator name with the 'VSI_NN_OP_' prefix removed when present.
       * Names without the prefix are returned unchanged.
       */
      get type() {
          const prefix = 'VSI_NN_OP_';
          // The fallback must read the backing field; reading `this.type` here would recurse forever.
          return this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this._type;
      }

      /** Operator schema for the unstripped operator name, or null. */
      get metadata() {
          return this._metadata.type(this._type);
      }

      /** Input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** Output parameters. */
      get outputs() {
          return this._outputs;
      }

      /** Attributes collected from `node.nn`. */
      get attributes() {
          return this._attributes;
      }
  };
  rknn.Attribute = class {

      /**
       * Name/value pair describing one operator setting.
       * @param {string} name - Attribute name.
       * @param {*} value - Attribute value, stored as given.
       */
      constructor(name, value) {
          this._value = value;
          this._name = name;
      }

      /** Attribute name. */
      get name() {
          return this._name;
      }

      /** Attribute value. */
      get value() {
          return this._value;
      }
  };
  rknn.Tensor = class {

      /**
       * Constant tensor whose bytes live inside the weights blob.
       *
       * The byte size is the element size times the product of the dimensions
       * (one element for an empty shape). Data types without a known element
       * size (e.g. 'int64') keep no data and report 'Tensor data is empty.'.
       *
       * @param {object} type - Tensor type; `dataType` and `shape.dimensions` are read.
       * @param {number} offset - Byte offset of the tensor inside `weights`.
       * @param {Uint8Array} weights - Weights blob; `slice(start, end)` is called on it and
       *     the result is later wrapped in a DataView (assumes a typed array — confirm against caller).
       */
      constructor(type, offset, weights) {
          this._type = type;
          let itemsize = 0;
          switch (this._type.dataType) {
              case 'uint8': itemsize = 1; break;
              case 'int8': itemsize = 1; break;
              case 'int16': itemsize = 2; break;
              case 'int32': itemsize = 4; break;
              case 'float16': itemsize = 2; break;
              case 'float32': itemsize = 4; break;
          }
          const shape = type.shape.dimensions;
          const size = itemsize * (shape.length === 0 ? 1 : shape.reduce((a, b) => a * b));
          if (size > 0) {
              this._data = weights.slice(offset, offset + size);
          }
      }

      /** Tensor type as given to the constructor. */
      get type() {
          return this._type;
      }

      /** Reason the tensor cannot be decoded, or null when it can. */
      get state() {
          return this._context().state || null;
      }

      /** Fully decoded nested array (a single value for an empty shape), or null when not decodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** JSON rendering of the decoded values, cut off with '...' once more than 10000 elements were read. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, ' ');
      }

      /**
       * Creates the decoding state: either `{ state: <message> }` when the tensor
       * cannot be decoded, or a cursor over the data bytes.
       */
      _context() {
          const context = {};
          if (!this._type.dataType) {
              context.state = 'Tensor data type is not implemented.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.index = 0;
          context.count = 0;
          context.shape = this._type.shape.dimensions;
          context.dataType = this._type.dataType;
          context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      /**
       * Recursively decodes one dimension. Multi-byte values are read little-endian.
       * @param {object} context - Cursor created by `_context()` with `limit` set.
       * @param {number} dimension - Index of the dimension being decoded.
       * @returns {Array|*} Nested array of values; the bare value for an empty shape.
       */
      _decode(context, dimension) {
          // An empty shape is decoded as a single element.
          const shape = context.shape.length !== 0 ? context.shape : [ 1 ];
          const results = [];
          const size = shape[dimension];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'float16':
                          // NOTE(review): DataView.prototype.getFloat16 needs a recent runtime or a polyfill — confirm.
                          results.push(context.view.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.view.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'uint8':
                          results.push(context.view.getUint8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int8':
                          results.push(context.view.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int16':
                          results.push(context.view.getInt16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.view.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  rknn.TensorType = class {

      /**
       * Element type plus shape of a tensor.
       * @param {object} dataType - Descriptor whose `vx_type` holds a 'VSI_NN_TYPE_*' name.
       * @param {object} shape - Tensor shape; its `toString()` is used by `toString()`.
       * @throws {rknn.Error} When `vx_type` is not one of the known type names.
       */
      constructor(dataType, shape) {
          const names = new Map([
              [ 'VSI_NN_TYPE_UINT8', 'uint8' ],
              [ 'VSI_NN_TYPE_INT8', 'int8' ],
              [ 'VSI_NN_TYPE_INT16', 'int16' ],
              [ 'VSI_NN_TYPE_INT32', 'int32' ],
              [ 'VSI_NN_TYPE_INT64', 'int64' ],
              [ 'VSI_NN_TYPE_FLOAT16', 'float16' ],
              [ 'VSI_NN_TYPE_FLOAT32', 'float32' ]
          ]);
          const name = names.get(dataType.vx_type);
          if (name === undefined) {
              throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
          }
          this._dataType = name;
          this._shape = shape;
      }

      /** Short element type name such as 'float32'. */
      get dataType() {
          return this._dataType;
      }

      /** Shape as given to the constructor. */
      get shape() {
          return this._shape;
      }

      /** Element type name immediately followed by the shape text. */
      toString() {
          const suffix = this._shape.toString();
          return this.dataType + suffix;
      }
  };
  rknn.TensorShape = class {

      /**
       * Wraps a list of dimensions.
       * @param {Array<number>} shape - Dimension sizes, stored as given.
       */
      constructor(shape) {
          this._dimensions = shape;
      }

      /** Dimension sizes as given to the constructor. */
      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dims = this._dimensions;
          const missing = !dims || dims.length === 0;
          return missing ? '' : '[' + dims.join(',') + ']';
      }
  };
  rknn.Container = class {

      /**
       * Creates a container when the stream starts with the RKNN signature
       * ('RKNN' followed by four zero bytes).
       * @param {object} stream - Must expose `length`, `peek(count)` and `read(count)`.
       * @returns {rknn.Container|null} The container, or null when the signature does not match.
       */
      static open(stream) {
          const signature = [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ];
          if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
              return new rknn.Container(stream);
          }
          return null;
      }

      /**
       * @param {object} stream - Stream positioned at the start of the container.
       */
      constructor(stream) {
          this._reader = new rknn.Container.StreamReader(stream);
      }

      /** Container version number read from the header. */
      get version() {
          this._read();
          return this._version;
      }

      /** Raw weights bytes. */
      get weights() {
          this._read();
          return this._weights;
      }

      /** Model description parsed from the embedded JSON text. */
      get model() {
          this._read();
          return this._model;
      }

      /**
       * Parses the container on first use: 8 signature bytes, a 64-bit version,
       * a length-prefixed weights blob and a length-prefixed JSON document.
       * The reader is dropped afterwards so parsing happens only once.
       */
      _read() {
          if (this._reader) {
              // Consume the 8 signature bytes already validated by open().
              this._reader.uint64();
              this._version = this._reader.uint64();
              this._weights = this._reader.read();
              const buffer = this._reader.read();
              const reader = json.TextReader.create(buffer);
              this._model = reader.read();
              delete this._reader;
          }
      }
  };

  rknn.Container.StreamReader = class {

      /**
       * Sequential reader that tracks its own position for bounds checking.
       * @param {object} stream - Must expose `length` and `read(count)` returning a Uint8Array.
       */
      constructor(stream) {
          this._stream = stream;
          this._length = stream.length;
          this._position = 0;
      }

      /**
       * Advances the tracked position.
       * @param {number} offset - Number of bytes about to be consumed.
       * @throws {rknn.Error} When the new position lies beyond the end of the stream.
       */
      skip(offset) {
          this._position += offset;
          if (this._position > this._length) {
              throw new rknn.Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
          }
      }

      /**
       * Reads a little-endian unsigned 64-bit integer as a Number.
       * Built from two standard 32-bit reads, so it does not rely on a
       * non-standard `DataView.prototype.getUint64`; exact up to 2^53 - 1.
       * @returns {number} The decoded value.
       */
      uint64() {
          this.skip(8);
          const buffer = this._stream.read(8);
          const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
          const low = view.getUint32(0, true);
          const high = view.getUint32(4, true);
          return high * 0x100000000 + low;
      }

      /**
       * Reads a block prefixed by its 64-bit byte length.
       * @returns {Uint8Array} The block bytes as returned by the stream.
       */
      read() {
          const size = this.uint64();
          this.skip(size);
          return this._stream.read(size);
      }
  };
  rknn.Metadata = class {

      /**
       * Loads 'rknn-metadata.json' once and caches the result on the class.
       * A failed request or parse yields (and caches) an empty metadata object.
       * @param {object} context - Must provide `request(file, encoding, base)` returning a Promise.
       * @returns {Promise<rknn.Metadata>} The shared metadata instance.
       */
      static open(context) {
          const cached = rknn.Metadata._metadata;
          if (cached) {
              return Promise.resolve(cached);
          }
          const remember = (data) => {
              rknn.Metadata._metadata = new rknn.Metadata(data);
              return rknn.Metadata._metadata;
          };
          return context.request('rknn-metadata.json', 'utf-8', null)
              .then((data) => remember(data))
              .catch(() => remember(null));
      }

      /**
       * @param {string|null} data - JSON text holding an array of operator schemas, or a falsy value.
       */
      constructor(data) {
          const entries = data ? JSON.parse(data).map((item) => [ item.name, item ]) : [];
          this._map = new Map(entries);
      }

      /**
       * Looks up an operator schema.
       * @param {string} name - Operator name.
       * @returns {object|null} The schema, or null when unknown.
       */
      type(name) {
          if (!this._map.has(name)) {
              return null;
          }
          return this._map.get(name);
      }

      /**
       * Looks up an attribute schema of an operator. The per-operator lookup
       * table is built on first use and stored on the schema as `attributeMap`.
       * @param {string} type - Operator name.
       * @param {string} name - Attribute name.
       * @returns {object|null} The attribute schema, or null when unknown.
       */
      attribute(type, name) {
          const schema = this.type(type);
          if (!schema) {
              return null;
          }
          if (!schema.attributeMap) {
              const table = {};
              if (schema.attributes) {
                  for (const attribute of schema.attributes) {
                      table[attribute.name] = attribute;
                  }
              }
              schema.attributeMap = table;
          }
          return schema.attributeMap[name] || null;
      }
  };
  rknn.Error = class extends Error {

      /**
       * Error raised for malformed or unsupported RKNN content.
       * @param {string} message - Description of the failure.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading RKNN model.';
          this.name = label;
      }
  };
  // Publish the model factory when a CommonJS-style `module.exports` object is present.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = rknn.ModelFactory;
      }
  }