ncnn.js 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941
  1. var ncnn = ncnn || {};
  2. var text = require('./text');
  3. var base = require('./base');
  4. // https://github.com/Tencent/ncnn/wiki/param-and-model-file-structure
  5. // https://github.com/Tencent/ncnn/wiki/operation-param-weight-table
  6. // https://github.com/Tencent/ncnn/wiki/operators
  // Detects ncnn files by extension and leading bytes, and opens them
  // together with their companion param/weights file.
  ncnn.ModelFactory = class {

      // Returns 'ncnn.model.bin', 'ncnn.model' or 'ncnn.weights' when the
      // context looks like the corresponding ncnn file, otherwise undefined.
      match(context) {
          const name = context.identifier.toLowerCase();
          const hasSuffix = (...suffixes) => suffixes.some((suffix) => name.endsWith(suffix));
          // Little-endian uint32 at the start of the stream, or undefined
          // when the stream is not longer than four bytes.
          const leadingUint32 = () => {
              const stream = context.stream;
              if (!(stream.length > 4)) {
                  return undefined;
              }
              const bytes = stream.peek(4);
              return (bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes [3] << 24) >>> 0;
          };
          if (hasSuffix('.param.bin', '.ncnnmodel')) {
              if (leadingUint32() === 0x007685DD) {
                  return 'ncnn.model.bin';
              }
          }
          if (hasSuffix('.param', '.cfg.ncnn')) {
              try {
                  const first = text.Reader.open(context.stream, 2048).read();
                  if (first !== undefined) {
                      const line = first.trim();
                      if (line === '7767517') {
                          return 'ncnn.model';
                      }
                      // Files without the magic line start with two unsigned integers.
                      const fields = line.split(' ');
                      if (fields.length === 2 && fields.every((field) => (field >>> 0) === parseFloat(field))) {
                          return 'ncnn.model';
                      }
                  }
              }
              catch (err) {
                  // continue regardless of error
              }
          }
          if (hasSuffix('.bin', '.weights.ncnn')) {
              if (name == 'snapshot_blob.bin' || name === 'v8_context_snapshot.bin') {
                  return undefined;
              }
              const weightFlags = [ 0x00000000, 0x00000001, 0x01306B47, 0x000D4B38, 0x0002C056 ];
              if (weightFlags.includes(leadingUint32())) {
                  return 'ncnn.weights';
              }
          }
          return undefined;
      }

      // Resolves to an ncnn.Model for the given match kind. The companion
      // file name is derived from the context identifier; when a companion
      // weights file cannot be used the model is built without weights.
      open(context, match) {
          return context.metadata('ncnn-metadata.json').then((metadata) => {
              const identifier = context.identifier.toLowerCase();
              const openBinary = (param, bin) => new ncnn.Model(metadata, new ncnn.BinaryParamReader(metadata, param), bin);
              const openText = (param, bin) => new ncnn.Model(metadata, new ncnn.TextParamReader(param), bin);
              // Identifier with its last `count` characters removed.
              const stem = (count) => context.identifier.substring(0, context.identifier.length - count);
              // Requests the weights file and builds the model with it,
              // retrying without weights if anything in that path fails.
              const openWithWeights = (file, create) => {
                  return context.request(file, null).then((stream) => {
                      const weights = stream.read();
                      return create(context.stream.peek(), weights);
                  }).catch(() => {
                      return create(context.stream.peek(), null);
                  });
              };
              if (match === 'ncnn.model') {
                  let bin = null;
                  if (identifier.endsWith('.param')) {
                      bin = stem(6) + '.bin';
                  }
                  else if (identifier.endsWith('.cfg.ncnn')) {
                      bin = stem(9) + '.weights.ncnn';
                  }
                  return openWithWeights(bin, openText);
              }
              if (match === 'ncnn.model.bin') {
                  return openWithWeights(stem(10) + '.bin', openBinary);
              }
              if (match === 'ncnn.weights') {
                  let content = null;
                  if (identifier.endsWith('bin')) {
                      content = stem(4) + '.param';
                  }
                  else if (identifier.endsWith('.weights.ncnn')) {
                      content = stem(13) + '.cfg.ncnn';
                  }
                  // Try the text param file first, then the binary param file.
                  return context.request(content, null).then((stream) => {
                      const param = stream.peek();
                      return openText(param, context.stream.peek());
                  }).catch(() => {
                      return context.request(content + '.bin', null).then((stream) => {
                          const param = stream.peek();
                          return openBinary(param, context.stream.peek());
                      });
                  });
              }
              throw new ncnn.Error("Unsupported ncnn format '" + match + "'.");
          });
      }
  };
  // A model wrapping the single graph built from a param reader and an
  // optional weights buffer.
  ncnn.Model = class {

      constructor(metadata, param, bin) {
          const graph = new ncnn.Graph(metadata, param, bin);
          this._graphs = [ graph ];
      }

      // Format label reported for this model.
      get format() {
          return 'ncnn';
      }

      // Array holding the one graph of this model.
      get graphs() {
          return this._graphs;
      }
  };
  // Builds graph inputs and nodes from the layers of a param reader.
  ncnn.Graph = class {

      constructor(metadata, param, bin) {
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          const blobReader = new ncnn.BlobReader(bin);
          const layers = param.layers;
          // Arguments are shared by name; the type passed on first use wins.
          const known = new Map();
          const arg = (name, type) => {
              if (known.has(name)) {
                  return known.get(name);
              }
              const created = new ncnn.Argument(name, type, null);
              known.set(name, created);
              return created;
          };
          // First pass: attribute '30', when it is a list, is consumed as
          // repeated [ rank, dim0, dim1, ... ] groups, one per layer output.
          for (const layer of layers) {
              const attributes = layer.attributes;
              const hints = attributes.get('30');
              if (!Array.isArray(hints)) {
                  continue;
              }
              const pending = hints.map((item) => parseInt(item, 10));
              for (const output of layer.outputs || []) {
                  const fits = pending.length > 0 && pending[0] <= pending.length - 1;
                  if (fits) {
                      const shape = new Array(pending.shift());
                      for (let axis = 0; axis < shape.length; axis++) {
                          shape[axis] = pending.shift();
                      }
                      arg(output, new ncnn.TensorType('?', new ncnn.TensorShape(shape)));
                  }
                  attributes.delete('30');
              }
          }
          // Second pass: 'Input' layers become graph inputs, all other
          // layers become nodes (created in layer order, sharing blobReader).
          const toDimension = (value) => {
              const parsed = parseInt(value, 10);
              return isNaN(parsed) ? value : parsed;
          };
          for (const layer of layers) {
              if (layer.type == 'Input') {
                  const dimensions = Array.from(layer.attributes.values()).map(toDimension);
                  const type = new ncnn.TensorType('float32', new ncnn.TensorShape(dimensions));
                  const args = layer.outputs.map((output) => new ncnn.Argument(output, type, null));
                  this._inputs.push(new ncnn.Parameter(layer.name, true, args));
              }
              else {
                  this._nodes.push(new ncnn.Node(metadata, blobReader, layer, arg));
              }
          }
      }

      get inputs() {
          return this._inputs;
      }

      // Never populated by the constructor; always the initial empty array.
      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  // A named input or output slot holding a list of arguments.
  ncnn.Parameter = class {

      constructor(name, visible, args) {
          this._name = name;
          this._visible = visible;
          this._arguments = args;
      }

      // Slot name as given to the constructor.
      get name() {
          return this._name;
      }

      // Visibility flag as given to the constructor.
      get visible() {
          return this._visible;
      }

      // The argument list as given to the constructor.
      get arguments() {
          return this._arguments;
      }
  };
  // A named value with an optional type and an optional initializer tensor.
  ncnn.Argument = class {

      // Throws ncnn.Error when name is not a string. Falsy type and
      // initializer values are stored as null.
      constructor(name, type, initializer) {
          const valid = typeof name === 'string';
          if (!valid) {
              throw new ncnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      get name() {
          return this._name;
      }

      // The initializer's type takes precedence over the stored type.
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  // A graph node built from one parsed layer: named inputs and outputs,
  // weight tensors read from the blob reader, an optional fused activation
  // chain and the remaining layer attributes.
  ncnn.Node = class {

      // metadata   - object whose type(name) returns the operator description (may return nothing)
      // blobReader - ncnn.BlobReader shared by all nodes; weights are consumed in layer order
      // layer      - { type, name?, inputs?, outputs?, attributes: Map } as produced by the param readers
      // arg        - function (name, type?) returning the shared ncnn.Argument for a name
      constructor(metadata, blobReader, layer, arg) {
          this._inputs = [];
          this._outputs = [];
          this._chain = [];
          this._name = layer.name || '';
          const type = layer.type;
          this._type = metadata.type(type);
          const attributeMetadata = this._type && this._type.attributes ? this._type.attributes : [];
          const attributes = layer.attributes;

          // Inputs described by the operator metadata come first; any
          // remaining inputs are named 'input' (index 0) or by their index.
          const inputs = layer.inputs || [];
          let inputIndex = 0;
          if (this._type && this._type.inputs) {
              for (const inputDef of this._type.inputs) {
                  if (inputIndex < inputs.length || inputDef.option != 'optional') {
                      const inputCount = (inputDef.option == 'variadic') ? (inputs.length - inputIndex) : 1;
                      const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount).filter((id) => id != '' || inputDef.option != 'optional').map((id) => arg(id));
                      this._inputs.push(new ncnn.Parameter(inputDef.name, true, inputArguments));
                      inputIndex += inputCount;
                  }
              }
          }
          this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
              const inputName = ((inputIndex + index) == 0) ? 'input' : (inputIndex + index).toString();
              return new ncnn.Parameter(inputName, true, [ arg(input) ]);
          }));

          // Outputs follow the same scheme as inputs.
          const outputs = layer.outputs || [];
          let outputIndex = 0;
          if (this._type && this._type.outputs) {
              for (const outputDef of this._type.outputs) {
                  if (outputIndex < outputs.length || outputDef.option != 'optional') {
                      const outputCount = (outputDef.option == 'variadic') ? (outputs.length - outputIndex) : 1;
                      const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => arg(id));
                      this._outputs.push(new ncnn.Parameter(outputDef.name, true, outputArguments));
                      outputIndex += outputCount;
                  }
              }
          }
          this._outputs.push(...outputs.slice(outputIndex).map((output, index) => {
              const outputName = ((outputIndex + index) == 0) ? 'output' : (outputIndex + index).toString();
              return new ncnn.Parameter(outputName, true, [ arg(output) ]);
          }));

          // Integer attribute lookup. Only a missing (or empty) attribute
          // yields the fallback; an explicit 0 is kept. The binary reader
          // stores numbers, so `value || fallback` would discard a real 0.
          const int = (key, fallback) => {
              const value = attributes.get(key);
              if (value === undefined || value === null || value === '') {
                  return fallback;
              }
              return parseInt(value, 10);
          };
          // Reads the next weight blob and appends it as an input.
          const weight = (name, dimensions, dataType) => this._weight(blobReader, name, dimensions, dataType);
          // Appends a chained node for the fused activation selected by
          // attribute '9'; index 0 and out-of-range values add nothing.
          const activation = () => {
              const activation_names = [ '', 'ReLU', 'LeakyReLU', 'Clip', 'Sigmoid', 'Mish', 'HardSwish' ];
              const activation_type = int('9', 0);
              if (activation_type > 0 && activation_type < activation_names.length) {
                  const fused = {
                      type: activation_names[activation_type],
                      attributes: new Map()
                  };
                  this._chain.push(new ncnn.Node(metadata, blobReader, fused, arg));
              }
          };

          // Weight layout per operator. The order of weight() calls matters
          // because blobs are read sequentially. Size-only attributes are
          // deleted so they are not listed among the node attributes.
          // A missing type description simply reads no weights.
          switch (this._type ? this._type.name : undefined) {
              case 'BatchNorm': {
                  const channels = int('0', 0);
                  weight('slope', [ channels ], 'float32');
                  weight('mean', [ channels ], 'float32');
                  weight('variance', [ channels ], 'float32');
                  weight('bias', [ channels ], 'float32');
                  break;
              }
              case 'InnerProduct': {
                  activation();
                  const num_output = int('0', 0);
                  const weight_data_size = int('2', 0);
                  weight('weight', [ num_output, weight_data_size / num_output ]);
                  if (int('1', 0) === 1) {
                      weight('bias', [ num_output ], 'float32');
                  }
                  attributes.delete('2');
                  break;
              }
              case 'Bias': {
                  const bias_data_size = int('0', 0);
                  weight('bias', [ bias_data_size ], 'float32');
                  break;
              }
              case 'Embed': {
                  const num_output = int('0', 0);
                  const weight_data_size = int('3', 0);
                  weight('weight', [ weight_data_size / num_output, num_output ]);
                  if (int('2', 0) === 1) {
                      weight('bias', [ num_output ], 'float32');
                  }
                  attributes.delete('3');
                  break;
              }
              case 'Convolution':
              case 'ConvolutionDepthWise':
              case 'Deconvolution':
              case 'DeconvolutionDepthWise': {
                  activation();
                  const num_output = int('0', 0);
                  const kernel_w = int('1', 0);
                  // A missing or zero kernel_h falls back to kernel_w.
                  const kernel_h = int('11', 0) || kernel_w;
                  const weight_data_size = int('6', 0);
                  weight('weight', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h), kernel_h, kernel_w ]);
                  if (int('5', 0) === 1) {
                      weight('bias', [ num_output ], 'float32');
                  }
                  attributes.delete('6');
                  break;
              }
              case 'Convolution1D':
              case 'ConvolutionDepthWise1D': {
                  activation();
                  const num_output = int('0', 0);
                  const kernel_w = int('1', 0);
                  const weight_data_size = int('6', 0);
                  weight('weight', [ num_output, weight_data_size / (num_output * kernel_w), kernel_w ]);
                  if (int('5', 0) === 1) {
                      weight('bias', [ num_output ], 'float32');
                  }
                  attributes.delete('6');
                  break;
              }
              case 'Convolution3D':
              case 'ConvolutionDepthWise3D': {
                  activation();
                  const num_output = int('0', 0);
                  const kernel_w = int('1', 0);
                  // Missing or zero kernel_h / kernel_d fall back to kernel_w.
                  const kernel_h = int('11', 0) || kernel_w;
                  const kernel_d = int('21', 0) || kernel_w;
                  const weight_data_size = int('6', 0);
                  weight('weight', [ num_output, weight_data_size / (num_output * kernel_w * kernel_h * kernel_d), kernel_d, kernel_h, kernel_w ]);
                  if (int('5', 0) === 1) {
                      weight('bias', [ num_output ], 'float32');
                  }
                  attributes.delete('6');
                  break;
              }
              case 'Quantize': {
                  const scale_data_size = int('0', 1);
                  weight('scale', [ scale_data_size ], 'float32');
                  break;
              }
              case 'Dequantize': {
                  const scale_data_size = int('0', 1);
                  const bias_data_size = int('1', 0);
                  weight('scale', [ scale_data_size ], 'float32');
                  weight('bias', [ bias_data_size ], 'float32');
                  break;
              }
              case 'Requantize': {
                  const scale_in_data_size = int('0', 1);
                  const scale_out_data_size = int('1', 1);
                  const bias_data_size = int('2', 0);
                  weight('scale_in', [ scale_in_data_size ], 'float32');
                  weight('scale_out', [ scale_out_data_size ], 'float32');
                  weight('bias', [ bias_data_size ], 'float32');
                  break;
              }
              case 'InstanceNorm': {
                  const affine = int('2', 1);
                  if (affine === 1) {
                      const channels = int('0', 0);
                      weight('gamma', [ channels ], 'float32');
                      weight('beta', [ channels ], 'float32');
                  }
                  break;
              }
              case 'Scale': {
                  const scale_data_size = int('0', 0);
                  // A size of -233 means no weights are read for this layer.
                  if (scale_data_size != -233) {
                      weight('scale', [ scale_data_size ], 'float32');
                      if (attributes.get('1') == '1') {
                          weight('bias', [ scale_data_size ], 'float32');
                      }
                  }
                  break;
              }
              case 'Normalize': {
                  const scale_data_size = int('3', 0);
                  weight('scale', [ scale_data_size ], 'float32');
                  break;
              }
              case 'PReLU': {
                  const num_slope = int('0', 0);
                  weight('slope', [ num_slope ], 'float32');
                  break;
              }
              case 'Padding': {
                  const per_channel_pad_data_size = int('6', 0);
                  weight('per_channel_pad_data', [ per_channel_pad_data_size ], 'float32');
                  break;
              }
              case 'MemoryData': {
                  const w = int('0', 0);
                  const h = int('1', 0);
                  const d = int('11', 0);
                  const c = int('2', 0);
                  // The highest non-zero extent decides the rank of the data.
                  if (d != 0) {
                      weight('data', [ c, d, h, w ], 'float32');
                  }
                  else if (c != 0) {
                      weight('data', [ c, h, w ], 'float32');
                  }
                  else if (h != 0) {
                      weight('data', [ h, w ], 'float32');
                  }
                  else if (w != 0) {
                      weight('data', [ w ], 'float32');
                  }
                  else {
                      weight('data', [ 1 ], 'float32');
                  }
                  break;
              }
              case 'GroupNorm': {
                  const affine = int('3', 1);
                  if (affine === 1) {
                      const channels = int('1', 0);
                      weight('gamma', [ channels ], 'float32');
                      weight('beta', [ channels ], 'float32');
                  }
                  break;
              }
              case 'LayerNorm': {
                  const channels = int('0', 0);
                  weight('gamma', [ channels ], 'float32');
                  weight('beta', [ channels ], 'float32');
                  break;
              }
              case 'RNN': {
                  const num_output = int('0', 0);
                  const weight_data_size = int('1', 0);
                  const direction = int('2', 0);
                  const num_directions = direction == 2 ? 2 : 1;
                  weight('weight_xc', [ num_directions, num_output, weight_data_size / num_directions / num_output ]);
                  weight('bias_c', [ num_directions, num_output ]);
                  weight('weight_hc', [ num_directions, num_output, num_output ]);
                  attributes.delete('1');
                  break;
              }
              case 'LSTM': {
                  const num_output = int('0', 0);
                  const weight_data_size = int('1', 0);
                  const direction = int('2', 0);
                  const num_directions = direction == 2 ? 2 : 1;
                  weight('weight_xc', [ num_directions, 4, num_output, weight_data_size / num_directions / num_output / 4 ]);
                  weight('bias_c', [ num_directions, 4, num_output ]);
                  weight('weight_hc', [ num_directions, 4, num_output, num_output ]);
                  attributes.delete('1');
                  break;
              }
              case 'GRU': {
                  const num_output = int('0', 0);
                  const weight_data_size = int('1', 0);
                  const direction = int('2', 0);
                  const num_directions = direction == 2 ? 2 : 1;
                  weight('weight_xc', [ num_directions, 3, num_output, weight_data_size / num_directions / num_output / 3 ]);
                  weight('bias_c', [ num_directions, 4, num_output ]);
                  weight('weight_hc', [ num_directions, 3, num_output, num_output ]);
                  attributes.delete('1');
                  break;
              }
              case 'MultiHeadAttention': {
                  const embed_dim = int('0', 0);
                  // Attribute '1' (num_head) is not needed for the shapes;
                  // attribute '2' (weight_data_size) is only removed below.
                  weight('weight_q', [ embed_dim, embed_dim ]);
                  weight('bias_q', [ embed_dim ], 'float32');
                  weight('weight_k', [ embed_dim, embed_dim ]);
                  weight('bias_k', [ embed_dim ], 'float32');
                  weight('weight_v', [ embed_dim, embed_dim ]);
                  weight('bias_v', [ embed_dim ], 'float32');
                  weight('weight_out', [ embed_dim, embed_dim ]);
                  weight('bias_out', [ embed_dim ], 'float32');
                  attributes.delete('2');
                  break;
              }
              default: {
                  break;
              }
          }

          // Whatever attributes remain are exposed, described by the
          // metadata entry found at the attribute's key.
          this._attributes = Array.from(attributes).map((attribute) => {
              const key = attribute[0];
              const value = attribute[1];
              return new ncnn.Attribute(attributeMetadata[key], key, value);
          });
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get chain() {
          return this._chain;
      }

      // Reads one blob of the given dimensions and appends it to the inputs
      // as a parameter holding a single unnamed, initialized argument.
      // When the reader returns nothing, the tensor has no data and uses
      // the requested data type, or '?' if none was requested.
      _weight(blobReader, name, dimensions, dataType) {
          const blob = blobReader.read(dimensions, dataType);
          dataType = blob ? (blob.dataType || '?') : (dataType || '?');
          const data = blob ? blob.data : null;
          this._inputs.push(new ncnn.Parameter(name, true, [
              new ncnn.Argument('', null, new ncnn.Tensor(new ncnn.TensorType(dataType, new ncnn.TensorShape(dimensions)), data))
          ]));
      }
  };
  // A layer attribute, optionally named, typed and converted according to
  // its metadata entry.
  ncnn.Attribute = class {

      // Without metadata the key is used as the name and the value is kept
      // unchanged. With metadata the value is converted by type, and the
      // attribute is hidden when the metadata says so or when the value
      // equals the declared default.
      constructor(metadata, key, value) {
          this._type = '';
          this._name = key;
          this._value = value;
          if (!metadata) {
              return;
          }
          this._name = metadata.name;
          if (metadata.type) {
              this._type = metadata.type;
          }
          const kind = this._type;
          if (kind === 'int32') {
              this._value = parseInt(this._value, 10);
          }
          else if (kind === 'float32') {
              this._value = parseFloat(this._value);
          }
          else if (kind === 'float32[]') {
              this._value = this._value.map((item) => parseFloat(item));
          }
          else if (kind) {
              // Any other declared type goes through the enum lookup.
              this._value = ncnn.Utility.value(this._value, kind);
          }
          const declares = (property) => Object.prototype.hasOwnProperty.call(metadata, property);
          if (declares('visible') && !metadata.visible) {
              this._visible = false;
          }
          else if (declares('default')) {
              const current = this._value;
              const isDefault = current == metadata.default || (current && current.toString() == metadata.default.toString());
              if (isDefault) {
                  this._visible = false;
              }
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      // True unless the constructor marked the attribute as hidden.
      get visible() {
          return this._visible != false;
      }
  };
  // A weight tensor: a tensor type plus the data it was created with.
  ncnn.Tensor = class {

      constructor(type, data) {
          this._type = type;
          this._data = data;
      }

      // Category label reported for every tensor of this class.
      get category() {
          return 'Weights';
      }

      get type() {
          return this._type;
      }

      // The data passed to the constructor, returned unchanged.
      get values() {
          return this._data;
      }
  };
  // Pairs a data type name with a tensor shape.
  ncnn.TensorType = class {

      // A falsy dataType is stored as '?'.
      constructor(dataType, shape) {
          this._dataType = dataType ? dataType : '?';
          this._shape = shape;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      // Data type name immediately followed by the shape's string form.
      toString() {
          const suffix = this._shape.toString();
          return this._dataType + suffix;
      }
  };
  // Wraps a list of tensor dimensions.
  ncnn.TensorShape = class {

      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      // '[d0,d1,...]' with falsy dimensions printed as '?'; an empty
      // string when no dimension list was given.
      toString() {
          if (!this._dimensions) {
              return '';
          }
          const labels = this._dimensions.map((dimension) => {
              return dimension ? dimension.toString() : '?';
          });
          return '[' + labels.join(',') + ']';
      }
  };
  // Helpers for presenting attribute values.
  ncnn.Utility = class {

      // Translates a string-encoded enum index into its label.
      //   value - attribute value; only strings are translated
      //   type  - enum type name from the operator metadata
      // Returns the label when type is a known enum and value parses to an
      // index inside its table; otherwise returns value unchanged.
      static value(value, type) {
          // The tables are built on first use and cached on the class.
          ncnn.Utility._enum = ncnn.Utility._enum || new Map([
              [ 'BinaryOpType', [ 'Add', 'Sub', 'Mul', 'Div', 'Max', 'Min', 'Pow', 'RSub', 'RDiv' ] ],
              [ 'CastOpType', [ 'Auto', 'Float32', 'Float16', 'Int8', 'BFloat16' ] ],
              [ 'EltwiseType', [ 'Prod', 'Sum', 'Max' ] ],
              [ 'PaddingType', [ 'Constant', 'Replicate', 'Reflect' ] ],
              [ 'PoolingType', [ 'Max', 'Average' ] ],
              [ 'InterpResizeType', [ '', 'Nearest', 'Bilinear', 'Bicubic' ] ],
              [ 'PermuteOrderType', [ 'WH WHC WHDC', 'HW HWC HWDC', 'WCH WDHC', 'CWH DWHC', 'HCW HDWC', 'CHW DHWC', 'WHCD', 'HWCD', 'WCHD', 'CWHD', 'HCWD', 'CHWD', 'WDCH', 'DWCH', 'WCDH', 'CWDH', 'DCWH', 'CDWH', 'HDCW', 'DHCW', 'HCDW', 'CHDW', 'DCHW', 'CDHW' ] ],
              [ 'ReductionOpType', [ 'Sum', 'ASum', 'SumSq', 'Mean', 'Max', 'Min', 'Prod', 'L1', 'L2', 'LogSum', 'LogSumExp' ] ],
              [ 'UnaryOpType', [ 'Abs', 'Neg', 'Floor', 'Ceil', 'Square', 'Sqrt', 'Rsq', 'Exp', 'Log', 'Sin', 'Cos', 'Tan', 'ASin', 'ACos', 'ATan', 'Reciprocal', 'Tanh' ] ]
          ]);
          // Read through the class rather than `this` so the method also
          // works when it is called without a receiver.
          const tables = ncnn.Utility._enum;
          if (tables.has(type) && typeof value === 'string') {
              const index = parseInt(value, 10);
              const list = tables.get(type);
              // Reject negative indices as well: list[-1] is undefined and
              // would replace the original value with nothing.
              if (Number.isInteger(index) && index >= 0 && index < list.length) {
                  return list[index];
              }
          }
          return value;
      }
  };
  // Parses a text .param file into a list of layer records.
  ncnn.TextParamReader = class {

      // buffer - content handed to text.Reader.open
      // Throws ncnn.Error when the header line is missing or is not two
      // unsigned integers, or when an attribute contains more than one '='.
      constructor(buffer) {
          const reader = text.Reader.open(buffer);
          const lines = [];
          for (let line = reader.read(); line !== undefined; line = reader.read()) {
              lines.push(line.trim());
          }
          // The magic line '7767517' is optional; the header follows it.
          const signature = lines.shift();
          const headerLine = signature !== '7767517' ? signature : lines.shift();
          if (headerLine === undefined) {
              // Empty file, or a file that ends right after the magic line.
              throw new ncnn.Error('Invalid header.');
          }
          const header = headerLine.split(' ');
          if (header.length !== 2 || !header.every((value) => (value >>> 0) === parseFloat(value))) {
              throw new ncnn.Error('Invalid header.');
          }
          const layers = [];
          for (const line of lines) {
              if (line.length === 0) {
                  continue;
              }
              // Columns: type, name, input count, output count, the input
              // names, the output names, then the attributes.
              const columns = line.split(' ').filter((s) => s.length != 0);
              const layer = {};
              layer.type = columns.shift();
              layer.name = columns.shift();
              const inputCount = parseInt(columns.shift(), 10);
              const outputCount = parseInt(columns.shift(), 10);
              layer.inputs = columns.splice(0, inputCount);
              layer.outputs = columns.splice(0, outputCount);
              layer.attributes = new Map();
              const attributes = layer.attributes;
              let index = 0;
              for (const column of columns) {
                  const parts = column.split('=');
                  if (parts.length > 2) {
                      throw new ncnn.Error("Invalid attribute '" + column + "'.");
                  }
                  // An attribute without '=' is keyed by its position.
                  let key = (parts.length === 2) ? parts[0].trim() : index.toString();
                  let value = (parts.length === 2) ? parts[1].trim() : parts[0].trim();
                  const keyInt = parseInt(key, 10);
                  if (keyInt < 0) {
                      // Negative keys carry a comma-separated list whose
                      // first entry is dropped; the key is mapped back by
                      // removing the 23300 offset.
                      value = value.split(',').map((v) => v.trim());
                      value.shift();
                      key = (-(keyInt + 23300)).toString();
                  }
                  attributes.set(key, value);
                  index++;
              }
              layers.push(layer);
          }
          this._layers = layers;
      }

      // Parsed layers in file order.
      get layers() {
          return this._layers;
      }
  };
  // Parses a binary param buffer into a list of layer records.
  ncnn.BinaryParamReader = class {

      // metadata - object whose type(index) maps a layer type index to an operator
      // buffer   - binary content handed to base.BinaryReader
      // Throws ncnn.Error when the leading int32 is not 0x007685DD.
      constructor(metadata, buffer) {
          const reader = new base.BinaryReader(buffer);
          const signature = reader.int32();
          if (signature !== 0x007685DD) {
              throw new ncnn.Error('Invalid signature.');
          }
          const layerCount = reader.int32();
          // The blob count is read to advance the reader and then ignored.
          reader.int32();
          this._layers = [];
          for (let layerIndex = 0; layerIndex < layerCount; layerIndex++) {
              const typeIndex = reader.int32();
              const operator = metadata.type(typeIndex);
              // Layers carry no names here, so the layer index is used.
              const layer = {
                  type: operator || typeIndex.toString(),
                  name: layerIndex.toString(),
                  attributes: new Map(),
                  inputs: [],
                  outputs: []
              };
              // Both counts precede the input and output id lists.
              const inputCount = reader.int32();
              const outputCount = reader.int32();
              for (let n = 0; n < inputCount; n++) {
                  layer.inputs.push(reader.int32().toString());
              }
              for (let n = 0; n < outputCount; n++) {
                  layer.outputs.push(reader.int32().toString());
              }
              // Attribute records continue until the id -233 is read.
              for (let id = reader.int32(); id != -233; id = reader.int32()) {
                  if (id <= -23300) {
                      // Array attribute: a length followed by that many int32 values.
                      const key = (-id - 23300).toString();
                      const length = reader.int32();
                      const values = [];
                      for (let n = 0; n < length; n++) {
                          values.push(reader.int32());
                      }
                      layer.attributes.set(key, values);
                  }
                  else {
                      const key = id.toString();
                      layer.attributes.set(key, reader.int32());
                  }
              }
              this._layers.push(layer);
          }
      }

      // Parsed layers in file order.
      get layers() {
          return this._layers;
      }
  };
  // Sequential reader over a weights buffer.
  ncnn.BlobReader = class {

      constructor(buffer) {
          this._buffer = buffer;
          this._position = 0;
      }

      // Reads the next blob.
      //   shape    - list of dimensions; their product is the element count
      //   dataType - element type; when omitted a 4-byte type flag is read first
      // Returns { dataType, data }, or null once no buffer is available.
      // data is a view into the buffer, or null for 'qint8' and whenever
      // the buffer was dropped during this call. Throws ncnn.Error for an
      // unknown flag or data type.
      read(shape, dataType) {
          if (!this._buffer) {
              return null;
          }
          if (!dataType) {
              if (this._position + 4 < this._buffer.length) {
                  // Little-endian flag selecting the storage type.
                  let type = 0;
                  for (let shift = 0; shift < 32; shift += 8) {
                      type |= this._buffer[this._position++] << shift;
                  }
                  const flags = new Map([
                      [ 0x00000000, 'float32' ],
                      [ 0x01306B47, 'float16' ],
                      [ 0x000D4B38, 'int8' ],
                      [ 0x00000001, 'qint8' ]
                  ]);
                  // 0x0002C056 (raw data with extra scaling) is not
                  // supported and is rejected like any unknown flag.
                  if (!flags.has(type)) {
                      throw new ncnn.Error("Unsupported weight type '" + type + "'.");
                  }
                  dataType = flags.get(type);
              }
              else {
                  // Not enough bytes left: stop reading from now on.
                  this._buffer = null;
              }
          }
          let data = null;
          let size = 1;
          if (shape) {
              for (const dimension of shape) {
                  size *= dimension;
              }
          }
          else {
              this._buffer = null;
          }
          if (this._buffer && dataType) {
              const start = this._position;
              if (dataType === 'qint8') {
                  // Skipped without exposing data: size bytes plus 1024 more.
                  this._position += size + 1024;
              }
              else {
                  const widths = new Map([
                      [ 'float32', 4 ],
                      [ 'float16', 2 ],
                      [ 'int8', 1 ]
                  ]);
                  if (!widths.has(dataType)) {
                      throw new ncnn.Error("Unsupported weight type '" + dataType + "'.");
                  }
                  this._position += size * widths.get(dataType);
                  data = this._buffer.subarray(start, this._position);
              }
          }
          return { dataType: dataType, data: data };
      }
  };
  // Error type thrown by this loader; `name` carries a fixed label.
  ncnn.Error = class extends Error {

      constructor(message) {
          super(message);
          const label = 'Error loading ncnn model.';
          this.name = label;
      }
  };
  // Export the factory only when a CommonJS-style module object exists.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = ncnn.ModelFactory;
      }
  }