xmodel.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. /* jshint esversion: 6 */
  // Namespace object for this loader. An already-defined global `xmodel` is reused; `var` keeps
  // it a global binding when the file is presumably loaded as a classic browser script.
  var xmodel = xmodel || {};
  // Protobuf runtime: reuse a global `protobuf` when present, otherwise load it through CommonJS.
  var protobuf = protobuf || require('./protobuf');
  /**
   * Entry point used by the host to recognize and open Vitis-AI xmodel (serial_v2.Graph) files.
   */
  xmodel.ModelFactory = class {

      /**
       * Cheap format probe.
       * @param {object} context - host context; `tags('pb')` is assumed to return a Map from
       *     protobuf field number to wire type (inferred from usage — confirm against the host).
       * @returns {boolean} true when field 5 is reported with wire type 2 (length-delimited).
       */
      match(context) {
          const tags = context.tags('pb');
          return tags.get(5) === 2;
      }

      /**
       * Loads the generated xmodel protobuf schema, decodes the stream and builds the model.
       * @param {object} context - host context providing `require(id)` (returning a Promise) and `stream`.
       * @returns {Promise<xmodel.Model>} the decoded model.
       * @throws {xmodel.Error} (as a rejection) when decoding or model construction fails.
       */
      open(context) {
          return context.require('./xmodel-proto').then(() => {
              try {
                  xmodel.proto = protobuf.get('xmodel').serial_v2;
                  const reader = protobuf.BinaryReader.open(context.stream);
                  const graph = xmodel.proto.Graph.decode(reader);
                  return new xmodel.Model(graph);
              }
              catch (error) {
                  // String(error) also copes with non-Error throw values such as null/undefined,
                  // for which error.toString() would itself throw a TypeError.
                  const message = error && error.message ? error.message : String(error);
                  throw new xmodel.Error('File format is not serial_v2.Graph (' + message.replace(/\.$/, '') + ').');
              }
          });
      }
  };
  /**
   * Model wrapper exposing name, format, producer and the (single) graph of a decoded xmodel.
   */
  xmodel.Model = class {

      /**
       * @param {object} graph - decoded serial_v2.Graph message.
       */
      constructor(graph) {
          this._name = graph.graph_name || '';
          this._format = 'Vitis-AI xmodel';
          // The producer comes from the optional 'origin' graph attribute, if it carries a string.
          const graphAttributes = graph && graph.graph_attr;
          const origin = graphAttributes && graphAttributes.origin;
          const producer = origin && origin.string_value;
          this._producer = producer || '';
          this._graphs = [ new xmodel.Graph(graph) ];
      }

      /** @returns {string} graph name, or '' when absent. */
      get name() {
          return this._name;
      }

      /** @returns {string} constant format label. */
      get format() {
          return this._format;
      }

      /** @returns {string} producer name, or ''. */
      get producer() {
          return this._producer;
      }

      /** @returns {xmodel.Graph[]} single-element graph list. */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * Computation graph built from a decoded serial_v2.Graph.
   * Ops without arguments are treated specially: a 'data-fix' op becomes a graph input, and a
   * 'const-fix' op referenced exactly once is folded into its consumer as an initializer tensor.
   * Graph outputs are not populated (the list stays empty).
   */
  xmodel.Graph = class {

      /**
       * @param {object} graph - decoded serial_v2.Graph (reads op_defs and op_node).
       */
      constructor(graph) {
          const metadata = new xmodel.Metadata(graph.op_defs);
          this._inputs = [];
          this._outputs = [];

          // How many times each op name is referenced as an argument of some op.
          const consumers = new Map();
          for (const node of graph.op_node) {
              for (const { arg_ops } of node.args) {
                  for (const source of arg_ops) {
                      consumers.set(source, (consumers.get(source) || 0) + 1);
                  }
              }
          }

          const initializers = new Map();
          const remaining = [];
          for (const node of graph.op_node) {
              const isSource = node.args.length === 0;
              if (isSource && node.op_type === 'const-fix' && consumers.get(node.op_name) === 1) {
                  const tensorType = xmodel.Utility.type(node.op_attr);
                  initializers.set(node.op_name, new xmodel.Tensor(tensorType, node.op_type));
              }
              else if (isSource && node.op_type === 'data-fix') {
                  const tensorType = xmodel.Utility.type(node.op_attr);
                  const { out } = xmodel.Utility.quantization(node.op_attr);
                  const argument = new xmodel.Argument(node.op_name, tensorType, out, null);
                  this._inputs.push(new xmodel.Parameter(node.op_name, [ argument ]));
              }
              else {
                  remaining.push(node);
              }
          }
          this._nodes = remaining.map((node) => new xmodel.Node(metadata, node, initializers));
      }

      /** @returns {xmodel.Parameter[]} one parameter per 'data-fix' source op. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {xmodel.Parameter[]} always empty. */
      get outputs() {
          return this._outputs;
      }

      /** @returns {xmodel.Node[]} every op that was not turned into an input or initializer. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Named group of arguments attached to a graph or node input/output slot.
   */
  xmodel.Parameter = class {

      /**
       * @param {string} name - parameter name.
       * @param {xmodel.Argument[]} args - arguments bound to this parameter.
       */
      constructor(name, args) {
          Object.assign(this, { _name: name, _arguments: args });
      }

      /** @returns {string} parameter name. */
      get name() {
          return this._name;
      }

      /** @returns {boolean} parameters are always shown. */
      get visible() {
          return true;
      }

      /** @returns {xmodel.Argument[]} bound arguments. */
      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A named value flowing between ops, optionally carrying quantization info and an initializer.
   */
  xmodel.Argument = class {

      /**
       * @param {string} name - argument identifier (the producing op's name).
       * @param {xmodel.TensorType|null} type - tensor type, if known.
       * @param {object|undefined} quantization - descriptor with optional bit_width/pos/signed/round_mode.
       * @param {xmodel.Tensor|null|undefined} initializer - constant tensor backing this argument.
       * @throws {xmodel.Error} when name is not a string.
       */
      constructor(name, type, quantization, initializer) {
          if (typeof name !== 'string') {
              throw new xmodel.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type || null;
          this._quantization = quantization;
          this._initializer = initializer || null;
      }

      /** @returns {string} argument identifier. */
      get name() {
          return this._name;
      }

      /** @returns {xmodel.TensorType|null} the initializer's type when present, else the declared type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      /**
       * Human-readable quantization summary, e.g. 'bit_width: 8, pos: 4'.
       * @returns {string|null} null when no quantization field is set.
       */
      get quantization() {
          const quantization = this._quantization;
          if (!quantization) {
              return null;
          }
          const list = [];
          for (const key of [ 'bit_width', 'pos', 'signed', 'round_mode' ]) {
              if (quantization[key] !== undefined) {
                  list.push(key + ': ' + quantization[key]);
              }
          }
          // An empty descriptor (op without quant_* attributes) means "no quantization", not ''.
          return list.length > 0 ? list.join(', ') : null;
      }

      /** @returns {xmodel.Tensor|null} backing constant tensor, if any. */
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * Graph node for a single serial_v2.OPNode.
   */
  xmodel.Node = class {

      /**
       * @param {xmodel.Metadata} metadata - operator schemas of the containing graph.
       * @param {object} op_node - decoded OPNode (op_name, op_type, op_attr, args).
       * @param {Map<string, xmodel.Tensor>} initializers - folded constant tensors keyed by op name;
       *     an input argument whose name is in this map gets that tensor as its initializer.
       */
      constructor(metadata, op_node, initializers) {
          this._name = op_node.op_name || '';
          this._type = metadata.type(op_node.op_type) || { name: op_node.op_type };
          this._inputs = [];
          this._outputs = [];
          this._attributes = [];
          for (const name of Object.keys(op_node.op_attr)) {
              // quant_in_* / quant_out_* are surfaced on the arguments instead; 'workload' is hidden.
              if (name.startsWith('quant_in_') || name.startsWith('quant_out_') || name === 'workload') {
                  continue;
              }
              const attribute = xmodel.Utility.attribute(op_node.op_attr[name]);
              // Attribute schemas are keyed by the op type *name*; passing the schema object
              // stringifies to '[object Object]' and the lookup (and default hiding) never matches.
              const schema = metadata.attribute(op_node.op_type, name);
              this._attributes.push(new xmodel.Attribute(schema, name, attribute.type, attribute.value));
          }
          // Input quantization is attached to every input argument, output quantization to the result.
          const quantization = xmodel.Utility.quantization(op_node.op_attr);
          for (const arg of op_node.args) {
              const args = arg.arg_ops.map((arg_op) => new xmodel.Argument(arg_op, null, quantization.in, initializers.get(arg_op)));
              this._inputs.push(new xmodel.Parameter(arg.arg_name, args));
          }
          this._outputs.push(new xmodel.Parameter('output', [
              new xmodel.Argument(op_node.op_name, null, quantization.out, null)
          ]));
      }

      /** @returns {object} operator schema, or { name: op_type } when unknown. */
      get type() {
          return this._type;
      }

      /** @returns {string} op name, or ''. */
      get name() {
          return this._name;
      }

      /** @returns {xmodel.Parameter[]} one parameter per OPNode argument group. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {xmodel.Parameter[]} a single 'output' parameter named after the op. */
      get outputs() {
          return this._outputs;
      }

      /** @returns {xmodel.Attribute[]} non-quantization attributes. */
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * Node attribute; hidden when its value equals the schema default.
   */
  xmodel.Attribute = class {

      /**
       * @param {object|undefined} metadata - attribute schema; its `default` (if defined) is compared to value.
       * @param {string} name - attribute name.
       * @param {string} type - attribute type label.
       * @param {*} value - decoded attribute value.
       */
      constructor(metadata, name, type, value) {
          this._name = name;
          this._type = type;
          this._value = value;
          this._visible = true;
          if (metadata && metadata.default !== undefined) {
              const defaultValue = metadata.default;
              // Primitives compare by identity; flat arrays compare element by element.
              const isDefault = defaultValue === value ||
                  (Array.isArray(defaultValue) && Array.isArray(value) &&
                   defaultValue.length === value.length &&
                   defaultValue.every((item, index) => item === value[index]));
              if (isDefault) {
                  this._visible = false;
              }
          }
      }

      /** @returns {string} attribute name. */
      get name() {
          return this._name;
      }

      /** @returns {string} attribute type label. */
      get type() {
          return this._type;
      }

      /** @returns {*} attribute value. */
      get value() {
          return this._value;
      }

      /** @returns {boolean} false when the value equals the schema default. */
      get visible() {
          return this._visible;
      }
  };
  /**
   * Tensor type: data type label plus optional shape.
   */
  xmodel.TensorType = class {

      /**
       * @param {string} dataType - data type label; falsy values become '?'.
       * @param {xmodel.TensorShape|undefined} shape - tensor shape, if known.
       */
      constructor(dataType, shape) {
          this._dataType = dataType || '?';
          this._shape = shape;
      }

      /** @returns {string} data type label ('?' when unknown). */
      get dataType() {
          return this._dataType;
      }

      /** @returns {xmodel.TensorShape|undefined} tensor shape. */
      get shape() {
          return this._shape;
      }

      /** @returns {string} e.g. 'int8[1,224,224,3]'; just the data type when no shape is known. */
      toString() {
          const shape = this._shape ? this._shape.toString() : '';
          return this._dataType + shape;
      }
  };
  /**
   * Tensor shape as a list of dimensions.
   */
  xmodel.TensorShape = class {

      /**
       * @param {Iterable|ArrayLike|null|undefined} dimensions - dimension values; copied into a
       *     new array. A missing list yields an empty (unknown) shape instead of throwing.
       */
      constructor(dimensions) {
          this._dimensions = dimensions ? Array.from(dimensions) : [];
      }

      /** @returns {Array} dimension values. */
      get dimensions() {
          return this._dimensions;
      }

      /** @returns {string} e.g. '[1,224,224,3]', or '' for an empty shape. */
      toString() {
          if (this._dimensions.length === 0) {
              return '';
          }
          return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
      }
  };
  /**
   * Constant tensor (e.g. a folded 'const-fix' op). Reading tensor data is not implemented:
   * `state` always reports that, `value` is null and `toString()` is ''.
   */
  xmodel.Tensor = class {

      /**
       * @param {xmodel.TensorType} type - tensor type.
       * @param {string} kind - op type the tensor came from.
       */
      constructor(type, kind) {
          this._type = type;
          this._kind = kind;
      }

      /** @returns {string} originating op type. */
      get kind() {
          return this._kind;
      }

      /** @returns {xmodel.TensorType} tensor type. */
      get type() {
          return this._type;
      }

      /** @returns {string|null} reason the data cannot be shown, or null. */
      get state() {
          const context = this._context();
          return context.state || null;
      }

      /** @returns {Array|null} decoded values, or null when decoding is unavailable. */
      get value() {
          const context = this._context();
          if (!context.state) {
              context.limit = Number.MAX_SAFE_INTEGER;
              return this._decode(context, 0);
          }
          return null;
      }

      /** @returns {string} pretty-printed (size-limited) values, or '' when unavailable. */
      toString() {
          const context = this._context();
          if (!context.state) {
              context.limit = 10000;
              return JSON.stringify(this._decode(context, 0), null, 4);
          }
          return '';
      }

      _context() {
          return {
              index: 0,
              count: 0,
              state: 'Tensor data not implemented.'
          };
      }

      _decode(/* context, dimension */) {
          return [];
      }
  };
  /**
   * Helpers for decoding serial_v2 attribute values.
   */
  xmodel.Utility = class {

      /**
       * Unpacks a serial_v2.AttrValue into { value, type }.
       * `attr.value` names the populated oneof field (e.g. 'int32_value'); the type is that name
       * without the '_value' suffix. 'bool' becomes 'boolean', and '<T>_vec' becomes '<T>[]' with
       * the wrapper message's `value` list unpacked.
       * @param {object|null|undefined} attr - decoded AttrValue.
       * @returns {{value: *, type: (string|undefined)}} both undefined when attr is missing or has no case set.
       */
      static attribute(attr) {
          const key = attr ? attr.value : undefined;
          if (!key) {
              return { value: undefined, type: undefined };
          }
          let value = attr[key];
          let type = key.replace(/_value$/, '');
          if (type.endsWith('_vec')) {
              const element = type.slice(0, -'_vec'.length);
              type = (element === 'bool' ? 'boolean' : element) + '[]';
              value = value ? value.value : value;
          }
          else if (type === 'bool') {
              type = 'boolean';
          }
          return { value: value, type: type };
      }

      /**
       * Builds a tensor type from an op's 'data_type' and 'shape' attributes.
       * XIR data type names such as 'XINT8', 'XUINT8', 'INT32' or 'FLOAT32' map to 'int8',
       * 'uint8', 'int32', 'float32'. The leading 'X' is dropped, as the original XINT8 -> int8 did.
       * @param {object} attr - op_attr map of the op.
       * @returns {xmodel.TensorType}
       * @throws {xmodel.Error} when the data type name is not recognized.
       */
      static type(attr) {
          const data_type = attr.data_type.string_value;
          const match = /^X?(U?INT|FLOAT|BFLOAT)(\d+)$/.exec(data_type || '');
          if (!match) {
              throw new xmodel.Error("Unknown data_type '" + data_type + "'.");
          }
          const dataType = match[1].toLowerCase() + match[2];
          const shape = attr.shape.int32_vec_value.value;
          return new xmodel.TensorType(dataType, new xmodel.TensorShape(shape));
      }

      /**
       * Collects quant_in_* / quant_out_* attributes into { in, out } descriptors with the fields
       * bit_width, pos (from '*_quantize_pos'), signed and round_mode. Only these attributes are
       * decoded, so unrelated attributes cannot break quantization lookup.
       * @param {object} attr - op_attr map of the op.
       * @returns {{in: object, out: object}}
       */
      static quantization(attr) {
          const fields = new Map([
              [ 'bit_width', 'bit_width' ],
              [ 'quantize_pos', 'pos' ],
              [ 'signed', 'signed' ],
              [ 'round_mode', 'round_mode' ]
          ]);
          const pattern = /^quant_(in|out)_(.+)$/;
          const quant = { in: {}, out: {} };
          for (const name of Object.keys(attr)) {
              const match = pattern.exec(name);
              if (match && fields.has(match[2])) {
                  quant[match[1]][fields.get(match[2])] = xmodel.Utility.attribute(attr[name]).value;
              }
          }
          return quant;
      }
  };
  /**
   * Operator schemas built from the graph's embedded op definitions, plus a fixed table of
   * display categories for known op types.
   */
  xmodel.Metadata = class {

      /**
       * @param {object[]|undefined} op_defs - decoded OpDef messages; a missing list yields empty metadata.
       */
      constructor(op_defs) {
          this._map = new Map();
          this._attributeCache = new Map();
          const categories = new Map([
              [ 'conv2d-fix', 'Layer' ],
              [ 'depthwise-conv2d-fix', 'Layer' ],
              [ 'upsample-fix', 'Layer' ],
              [ 'pool-fix', 'Pool' ],
              [ 'batchnorm', 'Normalization' ],
              [ 'concat-fix', 'Tensor' ],
              [ 'reshape-fix', 'Shape' ],
              [ 'softmax', 'Activation' ]
          ]);
          for (const op_def of op_defs || []) {
              const name = op_def.name;
              const schema = {};
              schema.name = name;
              if (op_def.annotation) {
                  schema.description = op_def.annotation;
              }
              schema.inputs = (op_def.input_args || []).map((input_arg) => {
                  const input = {};
                  input.name = input_arg.name;
                  if (input_arg.annotation) {
                      input.description = input_arg.annotation;
                  }
                  return input;
              });
              schema.attributes = (op_def.attrs || []).map((attr) => {
                  const attribute = {};
                  attribute.name = attr.name;
                  // Only decode a default when the AttrValue reports a populated case (`value`);
                  // nothing is decoded for an absent or empty default_value.
                  if (attr.default_value && attr.default_value.value) {
                      attribute.default = xmodel.Utility.attribute(attr.default_value).value;
                  }
                  if (attr.annotation) {
                      attribute.description = attr.annotation;
                  }
                  this._attributeCache.set(name + ':' + attr.name, attribute);
                  return attribute;
              });
              if (categories.has(name)) {
                  schema.category = categories.get(name);
              }
              this._map.set(name, schema);
          }
      }

      /**
       * @param {string} name - op type name.
       * @returns {object|undefined} schema for the op type.
       */
      type(name) {
          return this._map.get(name);
      }

      /**
       * @param {string|object} type - op type name, or a schema object ({ name }) as returned by type().
       * @param {string} name - attribute name.
       * @returns {object|undefined} attribute schema ({ name, default?, description? }).
       */
      attribute(type, name) {
          // Accept a schema object too; stringifying it would give '[object Object]' and never match.
          const typeName = type && typeof type === 'object' ? type.name : type;
          return this._attributeCache.get(typeName + ':' + name);
      }
  };
  /**
   * Error raised when an xmodel file cannot be loaded.
   */
  xmodel.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading xmodel.' });
      }
  };
  // Under CommonJS, expose the factory on module.exports; otherwise only the global `xmodel` namespace is set.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: xmodel.ModelFactory });
  }