xmodel.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  const xmodel = {};

  /**
   * Detects and opens xmodel (XIR serial_v2.Graph protobuf) files.
   */
  xmodel.ModelFactory = class {

      /**
       * Tags the context as 'xmodel.pb' when entry 5 of the context's 'pb' tag map equals 2.
       * NOTE(review): presumably field number -> wire type; confirm against the host's context API.
       *
       * @param {object} context - host context exposing `tags(kind)` and a writable `type`.
       */
      match(context) {
          const tags = context.tags('pb');
          if (tags.get(5) === 2) {
              context.type = 'xmodel.pb';
          }
      }

      /**
       * Loads the protobuf schema, decodes the graph and wraps it in an xmodel.Model.
       *
       * @param {object} context - host context exposing `require(path)` and `read(kind)`.
       * @returns {Promise<xmodel.Model>}
       * @throws {xmodel.Error} when reading or decoding the graph fails.
       */
      async open(context) {
          const schema = await context.require('./xmodel-proto');
          xmodel.proto = schema.serial_v2;
          let graph = null;
          try {
              const reader = context.read('protobuf.binary');
              graph = xmodel.proto.Graph.decode(reader);
          } catch (error) {
              // String() also copes with `null`/`undefined` throws, where
              // error.toString() would itself throw and hide the decode failure.
              const message = String(error && error.message ? error.message : error);
              throw new xmodel.Error(`File format is not serial_v2.Graph (${message.replace(/\.$/, '')}).`);
          }
          return new xmodel.Model(graph);
      }
  };
  /**
   * Model wrapper around a decoded serial_v2.Graph.
   *
   * Exposes the graph name, the fixed format 'xmodel', the producer (the string value of
   * the graph's 'origin' attribute, or '' when absent) and a single xmodel.Graph.
   */
  xmodel.Model = class {
      constructor(graph) {
          const attrs = graph && graph.graph_attr;
          const origin = attrs && attrs.origin;
          this.name = graph.graph_name || '';
          this.format = 'xmodel';
          this.producer = (origin && origin.string_value) || '';
          this.graphs = [ new xmodel.Graph(graph) ];
      }
  };
  /**
   * Graph view over a decoded serial_v2.Graph.
   *
   * Ops with no arguments and type 'data'/'data-fix' become graph inputs. Ops with no
   * arguments and type 'const'/'const-fix' that are referenced by exactly one argument
   * slot are folded into initializer values. Every other op becomes an xmodel.Node.
   * Graph outputs are left empty.
   */
  xmodel.Graph = class {
      constructor(graph) {
          const metadata = new xmodel.Metadata(graph.op_defs);
          this.inputs = [];
          this.outputs = [];
          // How many argument slots (across all ops) reference each op name.
          const consumers = new Map();
          for (const op of graph.op_node) {
              for (const { arg_ops } of op.args) {
                  for (const ref of arg_ops) {
                      consumers.set(ref, (consumers.get(ref) || 0) + 1);
                  }
              }
          }
          // Value table keyed by op name. `map` creates the value on first request and
          // returns the cached instance afterwards, so only the first call's node and
          // initializer flag take effect. xmodel.Node reuses this helper.
          const values = new Map();
          values.map = (name, node, initializer) => {
              let value = values.get(name);
              if (value === undefined) {
                  value = new xmodel.Value(name, node, initializer);
                  values.set(name, value);
              }
              return value;
          };
          const inputTypes = new Set([ 'data', 'data-fix' ]);
          const constantTypes = new Set([ 'const', 'const-fix' ]);
          const remaining = [];
          for (const op of graph.op_node) {
              const isSource = op.args.length === 0;
              if (isSource && inputTypes.has(op.op_type)) {
                  this.inputs.push(new xmodel.Argument(op.op_name, [ values.map(op.op_name, op) ]));
                  continue;
              }
              if (isSource && constantTypes.has(op.op_type) && consumers.get(op.op_name) === 1) {
                  values.map(op.op_name, op, true);
                  continue;
              }
              values.map(op.op_name, op);
              remaining.push(op);
          }
          this.nodes = remaining.map((op) => new xmodel.Node(metadata, op, values));
      }
  };
  /**
   * Named list of values attached to a node input/output or a graph input.
   *
   * @param {string} name - argument name.
   * @param {Array<xmodel.Value>} value - the values bound to this argument.
   */
  xmodel.Argument = class {
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  /**
   * A named value (graph edge), optionally typed from its producing op node.
   *
   * @param {string} name - identifier of the value (the producing op's name in xmodel.Graph).
   * @param {object} [node] - op node whose `output_tensor` supplies the type.
   * @param {boolean} [initializer] - when truthy, the node's data is exposed as an xmodel.Tensor.
   * @throws {xmodel.Error} if `name` is not a string.
   */
  xmodel.Value = class {
      constructor(name, node, initializer) {
          if (typeof name !== 'string') {
              throw new xmodel.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          const tensor = node ? node.output_tensor : null;
          // NOTE(review): a data_type of 0 is falsy and is treated as "no type" here, although
          // xmodel.TensorType maps 0 to 'int' — confirm whether that is intended.
          const typed = Boolean(tensor && tensor.tensor_attr && tensor.data_type);
          if (!typed) {
              return;
          }
          if (initializer) {
              this.initializer = new xmodel.Tensor(node);
              this.type = this.initializer.type;
          } else {
              this.type = new xmodel.TensorType(tensor);
          }
      }
  };
  /**
   * Graph node built from a serial_v2 op node.
   *
   * Attributes named 'device' populate `device`. 'workload' and 'quant_in_*'/'quant_out_*'
   * attributes are hidden. A set 'nonlinear' attribute (other than 'NONE') is rendered as a
   * fused activation node in `chain`. Every other attribute becomes an xmodel.Attribute.
   *
   * @param {xmodel.Metadata} metadata - resolves op types and attribute schemas.
   * @param {object} op_node - decoded op node (op_name, op_type, op_attr, args).
   * @param {Map} values - value table exposing the `map(name, node, initializer)` helper that
   *     xmodel.Graph installs.
   */
  xmodel.Node = class {
      constructor(metadata, op_node, values) {
          this.name = op_node.op_name || '';
          this.type = metadata.type(op_node.op_type);
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          this.chain = [];
          if (op_node.op_attr) {
              // Integer encoding of the 'nonlinear' attribute, indexed by value.
              const activations = [ 'none', 'relu', 'prelu', 'leakyrelu', 'relu6' ];
              for (const [name, obj] of Object.entries(op_node.op_attr)) {
                  if (name === 'device') {
                      this.device = obj.string_value;
                      continue;
                  }
                  if (name === 'workload' || name.startsWith('quant_in_') || name.startsWith('quant_out_')) {
                      continue;
                  }
                  const value = xmodel.Utility.attribute(obj);
                  // The truthiness check already excludes 0 (and '').
                  if (name === 'nonlinear' && value.value && value.value !== 'NONE') {
                      let activation = value.value;
                      if (typeof activation === 'string') {
                          activation = activation.toLowerCase();
                      } else if (Number.isInteger(activation) && activation >= 0 && activation < activations.length) {
                          // Bounds-checked: a negative index would yield an undefined op type.
                          activation = activations[activation];
                      } else {
                          activation = JSON.stringify(activation);
                      }
                      this.chain.push(new xmodel.Node(metadata, { op_type: activation }, values));
                      continue;
                  }
                  const attribute = new xmodel.Attribute(metadata.attribute(this.type, name), name, value);
                  this.attributes.push(attribute);
              }
          }
          if (op_node.args) {
              for (const input of op_node.args) {
                  const args = input.arg_ops.map((arg_op) => values.map(arg_op));
                  const argument = new xmodel.Argument(input.arg_name, args);
                  this.inputs.push(argument);
              }
          }
          if (op_node.op_name) {
              const argument = new xmodel.Argument('output', [ values.map(op_node.op_name) ]);
              this.outputs.push(argument);
          }
      }
  };
  /**
   * Displayable node attribute.
   *
   * The attribute is marked `visible = false` when the schema declares a default equal to
   * the value: identical by `===`, or arrays of equal length whose elements match by `===`.
   *
   * @param {object} [metadata] - attribute schema, possibly carrying `default`.
   * @param {string} name - attribute name.
   * @param {{type: string, value: *}} attribute - typed value as produced by xmodel.Utility.attribute.
   */
  xmodel.Attribute = class {
      constructor(metadata, name, attribute) {
          this.name = name;
          this.type = attribute.type;
          this.value = attribute.value;
          const fallback = metadata ? metadata.default : undefined;
          if (fallback === undefined) {
              return;
          }
          const current = this.value;
          const sameScalar = fallback === current;
          const sameArray = Array.isArray(fallback) && Array.isArray(current) &&
              fallback.length === current.length &&
              fallback.every((item, index) => item === current[index]);
          if (sameScalar || sameArray) {
              this.visible = false;
          }
      }
  };
  /**
   * Tensor type of a serial_v2 tensor: data type name + bit width, shape, and an optional
   * denotation built from the 'fix_point' and 'round_mode' tensor attributes.
   *
   * @param {object} tensor - decoded tensor (data_type, tensor_bit_width, tensor_dim, tensor_attr).
   * @throws {xmodel.Error} when data_type is not one of the codes 0-4.
   */
  xmodel.TensorType = class {
      constructor(tensor) {
          switch (tensor.data_type) {
              case 0: this.dataType = 'int'; break;
              case 1: this.dataType = 'uint'; break;
              case 2: this.dataType = 'xint'; break;
              case 3: this.dataType = 'xuint'; break;
              case 4: this.dataType = 'float'; break;
              default: throw new xmodel.Error(`Unsupported data type '${tensor.data_type}'.`);
          }
          this.dataType += tensor.tensor_bit_width.toString();
          this.shape = new xmodel.TensorShape(tensor.tensor_dim);
          if (tensor.tensor_attr) {
              // Collect the non-quantization attributes first. A Map keeps file-supplied keys
              // such as '__proto__' from being treated specially.
              const attr = new Map();
              for (const [key, obj] of Object.entries(tensor.tensor_attr)) {
                  if (!key.startsWith('quant_')) {
                      attr.set(key, obj[obj.value]);
                  }
              }
              // Build the denotation once, after all attributes are known.
              const denotation = [];
              const fix_point = attr.get('fix_point');
              if (fix_point !== undefined) {
                  denotation.push(`${fix_point}.`);
              }
              const round_mode = attr.get('round_mode');
              if (round_mode !== undefined) {
                  denotation.push(round_mode.toString());
              }
              if (denotation.length > 0) {
                  this.denotation = denotation.join(' ');
              }
          }
      }

      /**
       * @returns {string} data type followed by the shape, e.g. 'xint8[1,3]'.
       */
      toString() {
          return (this.dataType || '?') + this.shape.toString();
      }
  };
  /**
   * Tensor shape as a list of dimensions.
   *
   * @param {Iterable<number>} [dimensions] - dimension sizes; a missing value yields a scalar (empty) shape.
   */
  xmodel.TensorShape = class {
      constructor(dimensions) {
          this.dimensions = Array.from(dimensions || []);
      }

      /**
       * @returns {string} '' for a scalar shape, otherwise e.g. '[1,224,224,3]'.
       */
      toString() {
          if (!this.dimensions || this.dimensions.length === 0) {
              return '';
          }
          return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
      }
  };
  /**
   * Constant tensor (initializer) backed by an op node.
   *
   * The type comes from the node's output tensor and the category from its op type. When
   * the node's 'data' attribute carries bytes, they are exposed as `values` with encoding '<'.
   *
   * @param {object} node - op node with output_tensor, op_type and optional op_attr.data.
   */
  xmodel.Tensor = class {
      constructor(node) {
          this.type = new xmodel.TensorType(node.output_tensor);
          this.category = node.op_type;
          const data = node.op_attr ? node.op_attr.data : undefined;
          const bytes = data && data.bytes_value;
          if (bytes && bytes.value) {
              this.encoding = '<';
              this.values = bytes.value;
          }
      }
  };
  /**
   * Helpers for decoding serial_v2 protobuf values.
   */
  xmodel.Utility = class {

      /**
       * Converts a serial_v2 AttrValue into a `{ type, value }` pair.
       *
       * `attr_value.value` names the populated oneof field (e.g. 'int32_vec_value'). Vector
       * and map kinds unwrap one level (`.value`).
       *
       * @param {object} attr_value - decoded AttrValue.
       * @returns {{type: string, value: *}}
       * @throws {xmodel.Error} when no field is set or the kind is unsupported.
       */
      static attribute(attr_value) {
          const key = attr_value ? attr_value.value : undefined;
          // Without this guard an unset oneof fails with an opaque TypeError on `.replace`.
          if (typeof key !== 'string') {
              throw new xmodel.Error('Attribute value is not set.');
          }
          const type = key.replace(/_value$/, '');
          const value = attr_value[key];
          switch (type) {
              case 'bool': return { type: 'boolean', value: value };
              case 'bool_vec': return { type: 'boolean[]', value: value.value };
              case 'int32': return { type: 'int32', value: value };
              case 'int32_vec': return { type: 'int32[]', value: value.value };
              case 'uint32': return { type: 'uint32', value: value };
              case 'uint32_vec': return { type: 'uint32[]', value: value.value };
              case 'int64': return { type: 'int64', value: value };
              case 'int64_vec': return { type: 'int64[]', value: value.value };
              case 'uint64': return { type: 'uint64', value: value };
              case 'uint64_vec': return { type: 'uint64[]', value: value.value };
              case 'float': return { type: 'float32', value: value };
              case 'float_vec': return { type: 'float32[]', value: value.value };
              case 'double': return { type: 'float64', value: value };
              case 'double_vec': return { type: 'float64[]', value: value.value };
              case 'string': return { type: 'string', value: value };
              case 'string_vec': return { type: 'string[]', value: value.value };
              case 'bytes': return { type: 'byte[]', value: value.value };
              case 'map_string_2_int32': return { type: 'map<string,int32>', value: value.value };
              default: throw new xmodel.Error(`Unsupported attribute type '${type}'.`);
          }
      }
  };
  /**
   * Op type and attribute schema registry.
   *
   * It is seeded with a built-in table of op categories, then enriched with the
   * descriptions, inputs and attribute defaults declared in the graph's op_defs. Unknown op
   * types are registered on demand by `type()`.
   *
   * @param {Array<object>} op_defs - decoded op definitions (name, annotation, input_args, attrs).
   */
  xmodel.Metadata = class {
      constructor(op_defs) {
          this._attributes = new Map();
          const categories = [
              [ 'avgpool2d', 'Pool' ],
              [ 'batchnorm', 'Normalization' ],
              [ 'celu', 'Activation' ],
              [ 'concat-fix', 'Tensor' ],
              [ 'concat', 'Tensor' ],
              [ 'conv2d-fix', 'Layer' ],
              [ 'conv2d', 'Layer' ],
              [ 'depthwise-conv2d-fix', 'Layer' ],
              [ 'depthwise-conv2d', 'Layer' ],
              [ 'elu', 'Activation' ],
              [ 'fix', 'Quantization' ],
              [ 'fix2float', 'Quantization' ],
              [ 'flatten', 'Shape' ],
              [ 'float2fix', 'Quantization' ],
              [ 'gelu', 'Activation' ],
              [ 'hard-sigmoid', 'Activation' ],
              [ 'hard-sigmoid-fix', 'Activation' ],
              [ 'hard-swish', 'Activation' ],
              [ 'hard-tanh', 'Activation' ],
              [ 'identity', 'Control' ],
              [ 'inner-product', 'Layer' ],
              [ 'l2_normalize', 'Normalization' ],
              [ 'leaky-relu', 'Activation' ],
              [ 'leakyrelu', 'Activation' ],
              [ 'maxpool2d', 'Pool' ],
              [ 'pool-fix', 'Pool' ],
              [ 'relu', 'Activation' ],
              [ 'relu6', 'Activation' ],
              [ 'reshape-fix', 'Shape' ],
              [ 'reshape', 'Shape' ],
              [ 'scale', 'Layer' ],
              [ 'selu', 'Activation' ],
              [ 'shape', 'Shape' ],
              [ 'sigmoid', 'Activation' ],
              [ 'softmax', 'Activation' ],
              [ 'squeeze', 'Transform' ],
              [ 'stack', 'Tensor' ],
              [ 'strided_slice', 'Tensor' ],
              [ 'swish', 'Activation' ],
              [ 'tanh', 'Activation' ],
              [ 'threshold', 'Quantization' ],
              [ 'transpose', 'Tensor' ],
              [ 'transposed-conv2d', 'Layer' ],
              [ 'transposed-conv2d-fix', 'Layer' ],
              [ 'transposed-depthwise-conv2d', 'Layer' ],
              [ 'transposed-depthwise-conv2d-fix', 'Layer' ],
              [ 'upsample-fix', 'Data' ],
          ];
          // Single initialization; the former `new Map()` placeholder was immediately overwritten.
          this._types = new Map(categories.map(([name, category]) => [ name, { name: name, category: category } ]));
          for (const op_def of op_defs) {
              // Entries from the category table are extended in place.
              const type = this._types.get(op_def.name) || { name: op_def.name };
              if (op_def.annotation) {
                  type.description = op_def.annotation;
              }
              type.inputs = op_def.input_args.map((input_arg) => {
                  const input = {};
                  input.name = input_arg.name;
                  if (input_arg.annotation) {
                      input.description = input_arg.annotation;
                  }
                  return input;
              });
              type.attributes = op_def.attrs.map((attr) => {
                  const attribute = {};
                  attribute.name = attr.name;
                  // An attr without a default_value simply has no default, instead of
                  // aborting metadata construction.
                  if (attr.default_value) {
                      attribute.default = xmodel.Utility.attribute(attr.default_value).value;
                  }
                  if (attr.annotation) {
                      attribute.description = attr.annotation;
                  }
                  return attribute;
              });
              for (const attribute of type.attributes) {
                  this._attributes.set(`${type.name}:${attribute.name}`, attribute);
              }
              this._types.set(type.name, type);
          }
      }

      /**
       * Returns the type record for an op type name, registering a bare `{ name }` entry
       * the first time an unknown name is requested.
       *
       * @param {string} name - op type.
       * @returns {object}
       */
      type(name) {
          if (!this._types.has(name)) {
              this._types.set(name, { name: name });
          }
          return this._types.get(name);
      }

      /**
       * Looks up an attribute schema by `${type}:${name}`.
       *
       * @param {*} type - op type (stringified into the key).
       * @param {string} name - attribute name.
       * @returns {object|undefined}
       */
      attribute(type, name) {
          const key = `${type}:${name}`;
          return this._attributes.get(key);
      }
  };
  /**
   * Error type for xmodel files that cannot be loaded; its `name` carries the
   * user-facing prefix 'Error loading xmodel.'.
   */
  xmodel.Error = class extends Error {
      constructor(message) {
          super(message);
          const label = 'Error loading xmodel.';
          this.name = label;
      }
  };

  // The module's only export: the factory that matches and opens xmodel files.
  export const { ModelFactory } = xmodel;