xmodel.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  // If `xmodel` / `protobuf` are already defined (truthy) in this scope, reuse
  // them; otherwise create an empty `xmodel` namespace object and load
  // `protobuf` through CommonJS `require`.
  var xmodel = xmodel || {};
  var protobuf = protobuf || require('./protobuf');
  xmodel.ModelFactory = class {

      /**
       * Recognizes an xmodel file by inspecting its top-level protobuf tags.
       * @param {object} context - host open context; must provide `tags('pb')` returning a Map of field number to wire type.
       * @returns {string|undefined} 'xmodel.pb' when field 5 has wire type 2, otherwise undefined.
       */
      match(context) {
          const tags = context.tags('pb');
          // NOTE(review): field 5 with wire type 2 (length-delimited) is presumably the
          // serial_v2.Graph op node list — confirm against xmodel-proto.
          if (tags.get(5) === 2) {
              return 'xmodel.pb';
          }
          return undefined;
      }

      /**
       * Loads the generated protobuf schema, decodes the stream as serial_v2.Graph and
       * wraps the result in an xmodel.Model.
       * @param {object} context - host open context providing `require(path)` and `stream`.
       * @returns {Promise<xmodel.Model>}
       * @throws {xmodel.Error} (as a rejection) when the stream cannot be decoded.
       */
      open(context) {
          return context.require('./xmodel-proto').then(() => {
              let graph = null;
              try {
                  xmodel.proto = protobuf.get('xmodel').serial_v2;
                  const stream = context.stream;
                  const reader = protobuf.BinaryReader.open(stream);
                  graph = xmodel.proto.Graph.decode(reader);
              }
              catch (error) {
                  // String(...) rather than error.toString(): a thrown null/undefined must still
                  // surface as an xmodel.Error instead of a TypeError raised inside this handler.
                  const message = error && error.message ? error.message : String(error);
                  throw new xmodel.Error('File format is not serial_v2.Graph (' + message.replace(/\.$/, '') + ').');
              }
              return new xmodel.Model(graph);
          });
      }
  };
  xmodel.Model = class {

      /**
       * Model wrapper around one decoded serial_v2.Graph.
       * @param {object} graph - decoded graph message; `graph_name` and
       *     `graph_attr.origin.string_value` are optional.
       */
      constructor(graph) {
          const attributes = graph.graph_attr;
          const origin = attributes && attributes.origin;
          this._name = graph.graph_name || '';
          this._format = 'xmodel';
          // The producer is the graph's "origin" string attribute when present.
          this._producer = origin && origin.string_value ? origin.string_value : '';
          this._graphs = [ new xmodel.Graph(graph) ];
      }

      get name() {
          return this._name;
      }

      get format() {
          return this._format;
      }

      get producer() {
          return this._producer;
      }

      get graphs() {
          return this._graphs;
      }
  };
  xmodel.Graph = class {

      /**
       * Builds the displayable graph from a decoded serial_v2.Graph.
       * - Ops with no inputs of type 'data' / 'data-fix' become graph inputs.
       * - Ops with no inputs of type 'const' / 'const-fix' that are consumed exactly
       *   once become initializers of their consumer instead of standalone nodes.
       * - Every other op becomes an xmodel.Node.
       * Graph outputs are not populated here.
       * @param {object} graph - decoded graph with `op_defs` and `op_node`.
       */
      constructor(graph) {
          const metadata = new xmodel.Metadata(graph.op_defs);
          this._inputs = [];
          this._outputs = [];

          // Number of times each op name appears as an input of some op.
          const usage = new Map();
          for (const op_node of graph.op_node) {
              for (const input of op_node.args) {
                  for (const source of input.arg_ops) {
                      usage.set(source, (usage.get(source) || 0) + 1);
                  }
              }
          }

          // One shared xmodel.Argument per tensor name; the first registration wins.
          const values = new Map();
          const value = (name, node, initializer) => {
              let argument = values.get(name);
              if (argument === undefined) {
                  argument = new xmodel.Argument(name, node, initializer);
                  values.set(name, argument);
              }
              return argument;
          };

          const dataTypes = new Set([ 'data', 'data-fix' ]);
          const constTypes = new Set([ 'const', 'const-fix' ]);
          const kept = [];
          for (const node of graph.op_node) {
              const isSource = node.args.length === 0;
              if (isSource && dataTypes.has(node.op_type)) {
                  this._inputs.push(new xmodel.Parameter(node.op_name, [ value(node.op_name, node) ]));
                  continue;
              }
              if (isSource && usage.get(node.op_name) === 1 && constTypes.has(node.op_type)) {
                  value(node.op_name, node, true);
                  continue;
              }
              value(node.op_name, node);
              kept.push(node);
          }
          this._nodes = kept.map((node) => new xmodel.Node(metadata, node, value));
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Named input/output slot of a graph or node, holding a list of xmodel.Argument
   * values. Parameters are always visible.
   */
  xmodel.Parameter = class {

      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      get name() { return this._name; }

      get visible() { return true; }

      get arguments() { return this._arguments; }
  };
  xmodel.Argument = class {

      /**
       * A named tensor flowing between ops.
       * @param {string} name - tensor / producing op name.
       * @param {object} [node] - producing op node; its `output_tensor` supplies type information.
       * @param {boolean} [initializer] - when truthy, expose the tensor as an initializer
       *     (xmodel.Tensor) instead of a bare xmodel.TensorType.
       * @throws {xmodel.Error} when `name` is not a string.
       */
      constructor(name, node, initializer) {
          if (typeof name !== 'string') {
              throw new xmodel.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          const tensor = node ? node.output_tensor : null;
          // data_type 0 is a valid code ('int' in xmodel.TensorType), so test for
          // presence instead of truthiness; otherwise int tensors lose their type.
          if (tensor && tensor.tensor_attr && tensor.data_type !== undefined && tensor.data_type !== null) {
              if (initializer) {
                  this._initializer = new xmodel.Tensor(node);
              }
              else {
                  this._type = new xmodel.TensorType(tensor);
              }
          }
      }

      get name() {
          return this._name;
      }

      /** The initializer's type when there is one, otherwise the tensor type (may be undefined). */
      get type() {
          if (this._initializer) {
              return this._initializer.type;
          }
          return this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  xmodel.Node = class {

      /**
       * Displayable node built from a serial_v2 op node.
       * @param {xmodel.Metadata} metadata - operator metadata registry.
       * @param {object} op_node - decoded op node (`op_name`, `op_type`, optional `op_attr`, optional `args`).
       * @param {function(string): xmodel.Argument} arg - resolves a tensor name to its shared argument.
       */
      constructor(metadata, op_node, arg) {
          this._name = op_node.op_name || '';
          this._type = metadata.type(op_node.op_type);
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          this._chain = [];
          if (op_node.op_attr) {
              for (const entry of Object.entries(op_node.op_attr)) {
                  const name = entry[0];
                  if (name === 'device') {
                      this._device = entry[1].string_value;
                      continue;
                  }
                  // 'workload' and per-tensor quantization attributes are not displayed.
                  if (name === 'workload') {
                      continue;
                  }
                  if (name.startsWith('quant_in_') || name.startsWith('quant_out_')) {
                      continue;
                  }
                  const value = xmodel.Utility.attribute(entry[1]);
                  // A fused activation is shown as a chained node rather than an attribute.
                  // (A truthy value already excludes 0.)
                  if (name === 'nonlinear' && value.value && value.value !== 'NONE') {
                      let activation = value.value;
                      if (typeof activation === 'string') {
                          activation = activation.toLowerCase();
                      }
                      else if (Number.isInteger(activation) && activation >= 0 && activation < 5) {
                          activation = [ 'none', 'relu', 'prelu', 'leakyrelu', 'relu6' ][activation];
                      }
                      else {
                          // Unknown codes (including negative integers) are shown verbatim.
                          activation = JSON.stringify(activation);
                      }
                      this._chain.push(new xmodel.Node(metadata, { op_type: activation }, arg));
                      continue;
                  }
                  // Attribute metadata is keyed by the op type *string*; `this._type` is the
                  // metadata object from metadata.type() and would never match a key.
                  this._attributes.push(new xmodel.Attribute(metadata.attribute(op_node.op_type, name), name, value));
              }
          }
          if (op_node.args) {
              for (const input of op_node.args) {
                  const args = input.arg_ops.map((arg_op) => arg(arg_op));
                  this._inputs.push(new xmodel.Parameter(input.arg_name, args));
              }
          }
          // Each named op produces a single output tensor carrying the op's name.
          if (op_node.op_name) {
              this._outputs.push(new xmodel.Parameter('output', [ arg(op_node.op_name) ]));
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get device() {
          return this._device;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }

      get chain() {
          return this._chain;
      }
  };
  xmodel.Attribute = class {

      /**
       * Displayable op attribute.
       * @param {object} [metadata] - attribute metadata; when its `default` equals the value
       *     (strictly, or element-wise for arrays of equal length) the attribute is hidden.
       * @param {string} name - attribute name.
       * @param {{ type: string, value: * }} attribute - converted attribute value.
       */
      constructor(metadata, name, attribute) {
          this._name = name;
          this._type = attribute.type;
          this._value = attribute.value;
          const defaultValue = metadata ? metadata.default : undefined;
          if (defaultValue !== undefined) {
              const current = this._value;
              const sameArray = Array.isArray(defaultValue) && Array.isArray(current) &&
                  defaultValue.length === current.length &&
                  defaultValue.every((item, index) => item === current[index]);
              if (defaultValue === current || sameArray) {
                  this._visible = false;
              }
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      /** False only when the value matched the metadata default. */
      get visible() {
          return this._visible !== false;
      }
  };
  xmodel.TensorType = class {

      /**
       * Tensor type built from a decoded serial_v2 tensor.
       * The data type name is the base kind for `data_type` (0 int, 1 uint, 2 xint,
       * 3 xuint, 4 float) suffixed with `tensor_bit_width`, e.g. 'xint8'.
       * Non-'quant_' tensor attributes `fix_point` and `round_mode` form the denotation.
       * @param {object} tensor - decoded tensor (`data_type`, `tensor_bit_width`, `tensor_dim`, optional `tensor_attr`).
       * @throws {xmodel.Error} when `data_type` is not one of the known codes.
       */
      constructor(tensor) {
          switch (tensor.data_type) {
              case 0: this._dataType = 'int'; break;
              case 1: this._dataType = 'uint'; break;
              case 2: this._dataType = 'xint'; break;
              case 3: this._dataType = 'xuint'; break;
              case 4: this._dataType = 'float'; break;
              default: throw new xmodel.Error("Unsupported data type '" + String(tensor.data_type) + "'.");
          }
          this._dataType += tensor.tensor_bit_width.toString();
          this._shape = new xmodel.TensorShape(tensor.tensor_dim);
          if (tensor.tensor_attr) {
              const attr = {};
              for (const entry of Object.entries(tensor.tensor_attr)) {
                  const key = entry[0];
                  if (key.startsWith('quant_')) {
                      continue;
                  }
                  // entry[1].value names the populated field; read that field's value.
                  attr[key] = entry[1][entry[1].value];
              }
              // Built once after all attributes are collected (previously rebuilt per entry).
              const denotation = [];
              if (attr.fix_point !== undefined) {
                  denotation.push(attr.fix_point.toString() + '.');
              }
              if (attr.round_mode !== undefined) {
                  denotation.push(attr.round_mode.toString());
              }
              if (denotation.length > 0) {
                  this._denotation = denotation.join(' ');
              }
          }
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      get denotation() {
          return this._denotation;
      }

      toString() {
          return (this.dataType || '?') + this._shape.toString();
      }
  };
  xmodel.TensorShape = class {

      /**
       * @param {Iterable<number>} [dimensions] - tensor dimensions; a missing value is
       *     treated as a scalar / unknown shape (empty list) instead of throwing.
       */
      constructor(dimensions) {
          this._dimensions = Array.from(dimensions || []);
      }

      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          if (!this._dimensions || this._dimensions.length === 0) {
              return '';
          }
          return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
      }
  };
  xmodel.Tensor = class {

      /**
       * Constant tensor exposed as an argument initializer. Only the type and the kind
       * (the producing op type) are recorded; decoding tensor contents is not implemented.
       * @param {object} node - producing op node (uses `output_tensor` and `op_type`).
       */
      constructor(node) {
          this._type = new xmodel.TensorType(node.output_tensor);
          this._kind = node.op_type;
      }

      get kind() {
          return this._kind;
      }

      get type() {
          return this._type;
      }

      /** Reason the contents cannot be shown, or null when they can. */
      get state() {
          const context = this._context();
          return context.state ? context.state : null;
      }

      /** Decoded contents, or null while decoding is unavailable. */
      get value() {
          const context = this._context();
          if (!context.state) {
              context.limit = Number.MAX_SAFE_INTEGER;
              return this._decode(context, 0);
          }
          return null;
      }

      /** Pretty-printed contents (capped at 10000 elements), or '' while decoding is unavailable. */
      toString() {
          const context = this._context();
          if (!context.state) {
              context.limit = 10000;
              return JSON.stringify(this._decode(context, 0), null, 4);
          }
          return '';
      }

      _context() {
          return {
              index: 0,
              count: 0,
              state: 'Tensor data not implemented.'
          };
      }

      _decode(/* context, dimension */) {
          return [];
      }
  };
  xmodel.Utility = class {

      /**
       * Converts a decoded serial_v2 AttrValue into a `{ type, value }` pair.
       * `attr_value.value` holds the name of the populated field (e.g. 'int32_value');
       * the '_value' suffix is stripped to pick the conversion. Vector, bytes and map
       * wrappers are unwrapped through their inner `value` field.
       * NOTE(review): the extra uint32 / *_vec cases assume the same wrapper shape as
       * int32_vec — confirm against xmodel-proto.
       * @param {object} attr_value - decoded AttrValue message.
       * @returns {{ type: string, value: * }}
       * @throws {xmodel.Error} when no field is set or the field type is unsupported.
       */
      static attribute(attr_value) {
          const key = attr_value ? attr_value.value : undefined;
          if (typeof key !== 'string') {
              // Report an unset value as a load error instead of a TypeError from key.replace.
              throw new xmodel.Error('Attribute value is not set.');
          }
          const type = key.replace(/_value$/, '');
          const value = attr_value[key];
          switch (type) {
              case 'bool': return { type: 'boolean', value: value };
              case 'bool_vec': return { type: 'boolean[]', value: value.value };
              case 'int32': return { type: 'int32', value: value };
              case 'int32_vec': return { type: 'int32[]', value: value.value };
              case 'uint32': return { type: 'uint32', value: value };
              case 'uint32_vec': return { type: 'uint32[]', value: value.value };
              case 'int64': return { type: 'int64', value: value };
              case 'int64_vec': return { type: 'int64[]', value: value.value };
              case 'uint64': return { type: 'uint64', value: value };
              case 'uint64_vec': return { type: 'uint64[]', value: value.value };
              case 'float': return { type: 'float32', value: value };
              case 'float_vec': return { type: 'float32[]', value: value.value };
              case 'double': return { type: 'float64', value: value };
              case 'double_vec': return { type: 'float64[]', value: value.value };
              case 'string': return { type: 'string', value: value };
              case 'string_vec': return { type: 'string[]', value: value.value };
              case 'bytes': return { type: 'byte[]', value: value.value };
              case 'map_string_2_int32': return { type: 'map<string,int32>', value: value.value };
              default: throw new xmodel.Error("Unsupported attribute type '" + type + "'.");
          }
      }
  };
  xmodel.Metadata = class {

      /**
       * Operator metadata registry built from the graph's embedded op definitions plus a
       * built-in table of display categories.
       * @param {Array<object>} op_defs - decoded op definitions (`name`, `annotation`,
       *     `input_args`, `attrs` with `default_value`).
       */
      constructor(op_defs) {
          this._types = new Map();
          this._attributes = new Map();
          const categories = new Map([
              [ 'avgpool2d', 'Pool' ],
              [ 'batchnorm', 'Normalization' ],
              [ 'celu', 'Activation' ],
              [ 'concat-fix', 'Tensor' ],
              [ 'concat', 'Tensor' ],
              [ 'conv2d-fix', 'Layer' ],
              [ 'conv2d', 'Layer' ],
              [ 'depthwise-conv2d-fix', 'Layer' ],
              [ 'depthwise-conv2d', 'Layer' ],
              [ 'elu', 'Activation' ],
              [ 'fix', 'Quantization' ],
              [ 'fix2float', 'Quantization' ],
              [ 'flatten', 'Shape' ],
              [ 'float2fix', 'Quantization' ],
              [ 'gelu', 'Activation' ],
              [ 'hard-sigmoid', 'Activation' ],
              [ 'hard-sigmoid-fix', 'Activation' ],
              [ 'hard-swish', 'Activation' ],
              [ 'hard-tanh', 'Activation' ],
              [ 'identity', 'Control' ],
              [ 'inner-product', 'Layer' ],
              [ 'l2_normalize', 'Normalization' ],
              [ 'leaky-relu', 'Activation' ],
              [ 'leakyrelu', 'Activation' ],
              [ 'maxpool2d', 'Pool' ],
              [ 'pool-fix', 'Pool' ],
              [ 'relu', 'Activation' ],
              [ 'relu6', 'Activation' ],
              [ 'reshape-fix', 'Shape' ],
              [ 'reshape', 'Shape' ],
              [ 'scale', 'Layer' ],
              [ 'selu', 'Activation' ],
              [ 'shape', 'Shape' ],
              [ 'sigmoid', 'Activation' ],
              [ 'softmax', 'Activation' ],
              [ 'squeeze', 'Transform' ],
              [ 'stack', 'Tensor' ],
              [ 'strided_slice', 'Tensor' ],
              [ 'swish', 'Activation' ],
              [ 'tanh', 'Activation' ],
              [ 'threshold', 'Quantization' ],
              [ 'transpose', 'Tensor' ],
              [ 'transposed-conv2d', 'Layer' ],
              [ 'transposed-conv2d-fix', 'Layer' ],
              [ 'transposed-depthwise-conv2d', 'Layer' ],
              [ 'transposed-depthwise-conv2d-fix', 'Layer' ],
              [ 'upsample-fix', 'Data' ],
          ]);
          for (const op_def of op_defs) {
              const name = op_def.name;
              const metadata = {};
              metadata.name = name;
              if (op_def.annotation) {
                  metadata.description = op_def.annotation;
              }
              metadata.inputs = op_def.input_args.map((input_arg) => {
                  const input = {};
                  input.name = input_arg.name;
                  if (input_arg.annotation) {
                      input.description = input_arg.annotation;
                  }
                  return input;
              });
              metadata.attributes = op_def.attrs.map((attr) => {
                  const attribute = {};
                  attribute.name = attr.name;
                  const value = xmodel.Utility.attribute(attr.default_value);
                  attribute.default = value.value;
                  if (attr.annotation) {
                      attribute.description = attr.annotation;
                  }
                  // Indexed by '<op type>:<attribute name>' for attribute() lookups.
                  this._attributes.set(name + ':' + attr.name, attribute);
                  return attribute;
              });
              if (categories.has(name)) {
                  metadata.category = categories.get(name);
              }
              this._types.set(name, metadata);
          }
          // Ops without an embedded definition still get their display category.
          for (const entry of categories) {
              const name = entry[0];
              const category = entry[1];
              if (!this._types.has(name)) {
                  this._types.set(name, { name: name, category: category });
              }
          }
      }

      /**
       * Returns the metadata object for an op type, registering a bare `{ name }` entry
       * for unknown types so repeated lookups return the same object.
       */
      type(name) {
          if (!this._types.has(name)) {
              this._types.set(name, { name: name });
          }
          return this._types.get(name);
      }

      /**
       * Looks up attribute metadata.
       * @param {string|{name: string}} type - op type name, or the metadata object returned by type().
       * @param {string} name - attribute name.
       * @returns {object|undefined}
       */
      attribute(type, name) {
          // Accept the type object as well: stringifying it would produce
          // '[object Object]:name', which never matches a registered key.
          const typeName = type && typeof type === 'object' ? type.name : type;
          const key = typeName + ':' + name;
          return this._attributes.get(key);
      }
  };
  /**
   * Error type thrown by this loader for malformed or unsupported xmodel content.
   * Keeps the caller's message and sets a fixed `name` label.
   * @param {string} message - human-readable description of the problem.
   */
  xmodel.Error = class extends Error {

      constructor(message) {
          super(message);
          this.name = 'Error loading xmodel.';
      }
  };
  // Export through CommonJS only when a `module.exports` object exists. Only the
  // factory is exported; the other classes stay on the `xmodel` namespace object.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      module.exports.ModelFactory = xmodel.ModelFactory;
  }