xmodel.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  const xmodel = {};

  /**
   * Entry point used by the host to detect and open Vitis AI xmodel files,
   * which are decoded as protobuf `serial_v2.Graph` messages.
   */
  xmodel.ModelFactory = class {

      /**
       * Cheap format probe based on protobuf tags.
       * @param {object} context - host-provided file context.
       * @returns {Promise<*>} the result of `context.set('xmodel.pb')` when tag 5
       *     reports 2, otherwise null.
       */
      async match(context) {
          const tags = await context.tags('pb');
          // NOTE(review): presumably `tags` maps field number -> wire type (2 = length-delimited); confirm against the host API.
          return tags.get(5) === 2 ? context.set('xmodel.pb') : null;
      }

      /**
       * Loads the protobuf schema, decodes the graph and wraps it in a model.
       * @param {object} context - host-provided file context.
       * @returns {Promise<xmodel.Model>}
       * @throws {xmodel.Error} when the content cannot be decoded as `serial_v2.Graph`.
       */
      async open(context) {
          const protos = await context.require('./xmodel-proto');
          // The schema namespace is kept on `xmodel.proto`.
          xmodel.proto = protos.serial_v2;
          let graph;
          try {
              const reader = await context.read('protobuf.binary');
              graph = xmodel.proto.Graph.decode(reader);
          } catch (error) {
              let message;
              if (error && error.message) {
                  message = error.message;
              } else {
                  message = error.toString();
              }
              // Strip one trailing period so the wrapped message does not end in "..".
              throw new xmodel.Error(`File format is not serial_v2.Graph (${message.replace(/\.$/, '')}).`);
          }
          return new xmodel.Model(graph);
      }
  };
  /**
   * Top-level model: a single graph plus descriptive metadata.
   */
  xmodel.Model = class {

      /**
       * @param {object} graph - decoded `serial_v2.Graph`.
       */
      constructor(graph) {
          this.name = graph.graph_name || '';
          this.format = 'xmodel';
          // The producer comes from the optional `origin` graph attribute.
          const attributes = graph && graph.graph_attr;
          const origin = attributes && attributes.origin;
          this.producer = origin && origin.string_value ? origin.string_value : '';
          this.modules = [new xmodel.Graph(graph)];
      }
  };
  /**
   * Graph built from the `op_node` list of a decoded `serial_v2.Graph`.
   *
   * Argument-less `data`/`data-fix` ops become graph inputs. Argument-less
   * `const`/`const-fix` ops referenced exactly once are not emitted as nodes;
   * their value is created as an initializer instead. Every other op becomes
   * a node.
   */
  xmodel.Graph = class {

      /**
       * @param {object} graph - decoded `serial_v2.Graph` (reads `op_defs` and `op_node`).
       */
      constructor(graph) {
          const metadata = new xmodel.Metadata(graph.op_defs);
          this.inputs = [];
          this.outputs = [];
          // Number of references to each op name across all op arguments.
          const usage = new Map();
          for (const op of graph.op_node) {
              for (const arg of op.args) {
                  for (const source of arg.arg_ops) {
                      usage.set(source, (usage.get(source) || 0) + 1);
                  }
              }
          }
          // Value cache keyed by op name; `values.map` returns the cached value or creates it.
          const values = new Map();
          values.map = (name, node, initializer) => {
              let value = values.get(name);
              if (value === undefined) {
                  value = new xmodel.Value(name, node, initializer);
                  values.set(name, value);
              }
              return value;
          };
          const isInput = (type) => type === 'data' || type === 'data-fix';
          const isConstant = (type) => type === 'const' || type === 'const-fix';
          const nodes = [];
          for (const node of graph.op_node) {
              const leaf = node.args.length === 0;
              if (leaf && isInput(node.op_type)) {
                  const value = values.map(node.op_name, node);
                  this.inputs.push(new xmodel.Argument(node.op_name, [value]));
              } else if (leaf && usage.get(node.op_name) === 1 && isConstant(node.op_type)) {
                  values.map(node.op_name, node, true);
              } else {
                  values.map(node.op_name, node);
                  nodes.push(node);
              }
          }
          this.nodes = nodes.map((node) => new xmodel.Node(metadata, node, values));
      }
  };
  /**
   * A named slot on a graph, node or attribute list.
   */
  xmodel.Argument = class {

      /**
       * @param {string} name - slot name.
       * @param {*} value - list of values for inputs/outputs, or the raw attribute value.
       * @param {string|null} [type=null] - attribute type name, when known.
       * @param {boolean} [visible=true] - false hides the entry (used for default-valued attributes).
       */
      constructor(name, value, type = null, visible = true) {
          Object.assign(this, { name, value, type, visible });
      }
  };
  /**
   * A named value (edge) in the graph, identified by the producing op's name.
   */
  xmodel.Value = class {

      /**
       * @param {string} name - op name of the producer.
       * @param {object} [node] - producing `op_node`; its `output_tensor` supplies the type.
       * @param {boolean} [initializer] - when truthy, the producer's data is exposed
       *     as `this.initializer` and the type is taken from that tensor.
       * @throws {xmodel.Error} when `name` is not a string.
       */
      constructor(name, node, initializer) {
          if (typeof name !== 'string') {
              throw new xmodel.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          const tensor = node ? node.output_tensor : null;
          // data_type 0 is a valid code ('int' in xmodel.TensorType), so test for
          // presence rather than truthiness; otherwise INT tensors lose their type.
          const hasDataType = tensor && tensor.data_type !== undefined && tensor.data_type !== null;
          if (hasDataType && tensor.tensor_attr) {
              if (initializer) {
                  this.initializer = new xmodel.Tensor(node);
                  this.type = this.initializer.type;
              } else {
                  this.type = new xmodel.TensorType(tensor);
              }
          }
      }
  };
  /**
   * A graph node built from an `op_node`.
   *
   * The `device` attribute is exposed as a property. `workload`, `quant_in_*` and
   * `quant_out_*` attributes are skipped. A truthy `nonlinear` attribute other than
   * 'NONE' becomes a chained activation node. All other attributes become
   * arguments, hidden when they equal the schema default.
   */
  xmodel.Node = class {

      /**
       * @param {xmodel.Metadata} metadata - op type and attribute schemas.
       * @param {object} op_node - decoded op node, or a synthetic `{ op_type }` for chained activations.
       * @param {Map} values - value cache whose `map(name)` returns the value for an op name.
       */
      constructor(metadata, op_node, values) {
          this.name = op_node.op_name || '';
          this.type = metadata.type(op_node.op_type);
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          this.chain = [];
          if (op_node.op_attr) {
              for (const [name, obj] of Object.entries(op_node.op_attr)) {
                  if (name === 'device') {
                      this.device = obj.string_value;
                  } else if (name !== 'workload' && !name.startsWith('quant_in_') && !name.startsWith('quant_out_')) {
                      const attr = xmodel.Utility.attribute(obj);
                      // 0, '' and 'NONE' mean "no activation" and fall through to a plain attribute.
                      if (name === 'nonlinear' && attr.value && attr.value !== 'NONE') {
                          let activation = attr.value;
                          if (typeof activation === 'string') {
                              activation = activation.toLowerCase();
                          } else if (Number.isInteger(activation) && activation >= 0 && activation < 5) {
                              // Integer codes index this table; negative codes fall back to their JSON text.
                              activation = ['none', 'relu', 'prelu', 'leakyrelu', 'relu6'][activation];
                          } else {
                              activation = JSON.stringify(activation);
                          }
                          this.chain.push(new xmodel.Node(metadata, { op_type: activation }, values));
                      } else {
                          const schema = metadata.attribute(this.type.name, name);
                          const value = attr.value;
                          let isDefault = false;
                          if (schema && schema.default !== undefined) {
                              const fallback = schema.default;
                              // Arrays are compared element-wise; the previous code tested
                              // `this.value`, which is never set, so array defaults never matched.
                              isDefault = fallback === value ||
                                  (Array.isArray(fallback) && Array.isArray(value) &&
                                      fallback.length === value.length &&
                                      fallback.every((item, index) => item === value[index]));
                          }
                          this.attributes.push(new xmodel.Argument(name, value, attr.type, !isDefault));
                      }
                  }
              }
          }
          if (op_node.args) {
              for (const input of op_node.args) {
                  const argument = new xmodel.Argument(input.arg_name, input.arg_ops.map((arg_op) => values.map(arg_op)));
                  this.inputs.push(argument);
              }
          }
          if (op_node.op_name) {
              const argument = new xmodel.Argument('output', [values.map(op_node.op_name)]);
              this.outputs.push(argument);
          }
      }
  };
  /**
   * Tensor type built from a decoded `serial_v2.Tensor`: element type with bit
   * width, shape, and an optional denotation from `fix_point` / `round_mode`.
   */
  xmodel.TensorType = class {

      /**
       * @param {object} tensor - decoded tensor; reads `data_type`, `tensor_bit_width`,
       *     `tensor_dim` and the optional `tensor_attr` map.
       * @throws {xmodel.Error} when `data_type` is not one of the codes 0-5.
       */
      constructor(tensor) {
          let type = '';
          switch (tensor.data_type) {
              case 0: type = 'int'; break;
              case 1: type = 'uint'; break;
              case 2: type = 'xint'; break;
              case 3: type = 'xuint'; break;
              case 4: type = 'float'; break;
              case 5: type = 'bfloat'; break;
              default: throw new xmodel.Error(`Unsupported data type '${tensor.data_type}'.`);
          }
          // e.g. 'xint' + 8 -> 'xint8'.
          this.dataType = type + tensor.tensor_bit_width.toString();
          this.shape = new xmodel.TensorShape(tensor.tensor_dim);
          if (tensor.tensor_attr) {
              // Gather the non-quantization attributes first ...
              const attr = {};
              for (const [key, obj] of Object.entries(tensor.tensor_attr)) {
                  if (!key.startsWith('quant_')) {
                      // `obj.value` names the populated field of the attribute value.
                      attr[key] = obj[obj.value];
                  }
              }
              // ... then derive the denotation once, instead of on every loop iteration.
              const denotation = [];
              if (attr.fix_point !== undefined) {
                  denotation.push(`${attr.fix_point}.`);
              }
              if (attr.round_mode !== undefined) {
                  denotation.push(attr.round_mode.toString());
              }
              if (denotation.length > 0) {
                  this.denotation = denotation.join(' ');
              }
          }
      }

      /**
       * @returns {string} data type followed by the shape text, e.g. 'xint8[1,3]'.
       */
      toString() {
          return (this.dataType || '?') + this.shape.toString();
      }
  };
  /**
   * Tensor shape held as a list of dimension sizes.
   */
  xmodel.TensorShape = class {

      /**
       * @param {Iterable|ArrayLike} dimensions - dimension sizes; copied into a new array.
       */
      constructor(dimensions) {
          this.dimensions = Array.from(dimensions);
      }

      /**
       * @returns {string} `[d0,d1,...]`, or an empty string when there are no dimensions.
       */
      toString() {
          const dims = this.dimensions;
          const empty = !dims || dims.length === 0;
          return empty ? '' : `[${dims.map((size) => size.toString()).join(',')}]`;
      }
  };
  /**
   * Constant tensor backed by the `data` attribute of a const op.
   */
  xmodel.Tensor = class {

      /**
       * @param {object} node - producing op node; reads `output_tensor`, `op_type` and `op_attr.data`.
       */
      constructor(node) {
          this.type = new xmodel.TensorType(node.output_tensor);
          this.category = node.op_type;
          const data = node.op_attr ? node.op_attr.data : undefined;
          const bytes = data && data.bytes_value;
          if (bytes && bytes.value) {
              // NOTE(review): '<' presumably marks little-endian raw bytes; confirm with the consumer.
              this.encoding = '<';
              this.values = bytes.value;
          }
      }
  };
  /**
   * Helpers shared by nodes and metadata.
   */
  xmodel.Utility = class {

      /**
       * Converts a decoded attribute value into `{ type, value }`.
       *
       * `attr_value.value` names the populated field (e.g. 'int32_vec_value'). The
       * vector, bytes and map kinds carry their payload in a nested `.value` field.
       * @param {object} attr_value - decoded attribute value.
       * @returns {{type: string, value: *}}
       * @throws {xmodel.Error} when the field name is missing or the kind is unsupported.
       */
      static attribute(attr_value) {
          const key = attr_value.value;
          if (typeof key !== 'string') {
              // Report a loader error instead of a TypeError from `key.replace`.
              throw new xmodel.Error(`Unsupported attribute type '${String(key)}'.`);
          }
          const type = key.replace(/_value$/, '');
          const value = attr_value[key];
          switch (type) {
              case 'bool': return { type: 'boolean', value };
              case 'int32': return { type: 'int32', value };
              case 'int32_vec': return { type: 'int32[]', value: value.value };
              case 'uint32': return { type: 'uint32', value };
              case 'uint32_vec': return { type: 'uint32[]', value: value.value };
              case 'int64': return { type: 'int64', value };
              case 'uint64': return { type: 'uint64', value };
              case 'float': return { type: 'float32', value };
              case 'float_vec': return { type: 'float32[]', value: value.value };
              case 'double': return { type: 'float64', value };
              // Unwrap like the other *_vec kinds; the wrapper object was returned before.
              case 'double_vec': return { type: 'float64[]', value: value.value };
              case 'string': return { type: 'string', value };
              case 'string_vec': return { type: 'string[]', value: value.value };
              case 'bytes': return { type: 'byte[]', value: value.value };
              case 'map_string_2_int32': return { type: 'map<string,int32>', value: value.value };
              default: throw new xmodel.Error(`Unsupported attribute type '${type}'.`);
          }
      }
  };
  /**
   * Op type and attribute schemas: a built-in category table merged with the
   * `op_defs` carried by the model file.
   */
  xmodel.Metadata = class {

      /**
       * @param {Array<object>} op_defs - op definitions; reads `name`, `annotation`,
       *     `input_args` and `attrs` (with `default_value`).
       */
      constructor(op_defs) {
          // Built-in display categories, grouped by category.
          const categories = {
              Activation: [
                  'celu', 'elu', 'gelu', 'hard-sigmoid', 'hard-sigmoid-fix', 'hard-swish', 'hard-tanh',
                  'leaky-relu', 'leakyrelu', 'relu', 'relu6', 'selu', 'sigmoid', 'softmax', 'swish', 'tanh',
              ],
              Control: ['identity'],
              Data: ['upsample-fix'],
              Layer: [
                  'conv2d', 'conv2d-fix', 'depthwise-conv2d', 'depthwise-conv2d-fix', 'inner-product', 'scale',
                  'transposed-conv2d', 'transposed-conv2d-fix',
                  'transposed-depthwise-conv2d', 'transposed-depthwise-conv2d-fix',
              ],
              Normalization: ['batchnorm', 'l2_normalize'],
              Pool: ['avgpool2d', 'maxpool2d', 'pool-fix'],
              Quantization: ['fix', 'fix2float', 'float2fix', 'threshold'],
              Shape: ['flatten', 'reshape', 'reshape-fix', 'shape'],
              Tensor: ['concat', 'concat-fix', 'stack', 'strided_slice', 'transpose'],
              Transform: ['squeeze'],
          };
          this._types = new Map();
          for (const [category, names] of Object.entries(categories)) {
              for (const name of names) {
                  this._types.set(name, { name, category });
              }
          }
          this._attributes = new Map();
          for (const op_def of op_defs) {
              // Extend the built-in entry when one exists, otherwise start a new type.
              const type = this._types.get(op_def.name) || { name: op_def.name };
              if (op_def.annotation) {
                  type.description = op_def.annotation;
              }
              type.inputs = op_def.input_args.map((arg) => {
                  const input = { name: arg.name };
                  if (arg.annotation) {
                      input.description = arg.annotation;
                  }
                  return input;
              });
              type.attributes = op_def.attrs.map((attr) => {
                  const attribute = {
                      name: attr.name,
                      default: xmodel.Utility.attribute(attr.default_value).value,
                  };
                  if (attr.annotation) {
                      attribute.description = attr.annotation;
                  }
                  return attribute;
              });
              for (const attribute of type.attributes) {
                  this._attributes.set(`${type.name}:${attribute.name}`, attribute);
              }
              this._types.set(type.name, type);
          }
      }

      /**
       * Returns the schema for an op type, creating and caching a bare `{ name }` entry for unknown types.
       * @param {string} name - op type name.
       * @returns {object}
       */
      type(name) {
          let type = this._types.get(name);
          if (!type) {
              type = { name };
              this._types.set(name, type);
          }
          return type;
      }

      /**
       * @param {string} type - op type name.
       * @param {string} name - attribute name.
       * @returns {object|undefined} the attribute schema declared by `op_defs`, if any.
       */
      attribute(type, name) {
          return this._attributes.get(`${type}:${name}`);
      }
  };
  /**
   * Error raised when an xmodel file cannot be loaded.
   */
  xmodel.Error = class extends Error {

      /**
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          this.name = 'Error loading xmodel.';
      }
  };

  const { ModelFactory } = xmodel;

  export { ModelFactory };