uff.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  const uff = {};

  /**
   * Detects and loads UFF models stored as binary protobuf or protobuf text.
   */
  uff.ModelFactory = class {

      /**
       * Checks whether the file looks like a UFF MetaGraph.
       *
       * Binary candidates (`.uff`, `.pb`) are accepted when tags 1-4 (and 5, if
       * present) carry the expected values; text candidates (`.pbtxt`,
       * `*.uff.txt`) must contain `version`, `descriptors` and `graphs`.
       *
       * @param {object} context - host file context (`identifier`, `tags()`, `set()`).
       * @returns {Promise<*>} the result of `context.set('uff.pb' | 'uff.pbtxt')`, or null.
       */
      async match(context) {
          const identifier = context.identifier;
          const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
          if (extension === 'uff' || extension === 'pb') {
              const tags = await context.tags('pb');
              // NOTE(review): values presumably are protobuf wire types keyed by field number; confirm against context.tags.
              const is = (tag, value) => tags.has(tag) && tags.get(tag) === value;
              if (is(1, 0) && is(2, 0) && is(3, 2) && is(4, 2) && (!tags.has(5) || tags.get(5) === 2)) {
                  return context.set('uff.pb');
              }
          } else if (extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt')) {
              const tags = await context.tags('pbtxt');
              if (tags.has('version') && tags.has('descriptors') && tags.has('graphs')) {
                  return context.set('uff.pbtxt');
              }
          }
          return null;
      }

      /**
       * Decodes the MetaGraph selected by `match` and wraps it in a uff.Model.
       *
       * @param {object} context - host file context (`type`, `require()`, `read()`, `metadata()`).
       * @returns {Promise<uff.Model>} the loaded model.
       * @throws {uff.Error} when decoding fails or `context.type` is not a UFF type.
       */
      async open(context) {
          const proto = await context.require('./uff-proto');
          // Published on the namespace so other classes (e.g. uff.TensorType) can read uff.proto.DataType.
          uff.proto = proto.uff;
          // Turns any thrown value into a message without a trailing period so that
          // both decode paths report failures identically (and non-Error throws
          // no longer surface as "undefined").
          const reason = (error) => {
              const message = error && error.message ? String(error.message) : String(error);
              return message.replace(/\.$/, '');
          };
          let meta_graph = null;
          switch (context.type) {
              case 'uff.pb': {
                  try {
                      const reader = await context.read('protobuf.binary');
                      meta_graph = uff.proto.MetaGraph.decode(reader);
                  } catch (error) {
                      throw new uff.Error(`File format is not uff.MetaGraph (${reason(error)}).`);
                  }
                  break;
              }
              case 'uff.pbtxt': {
                  try {
                      const reader = await context.read('protobuf.text');
                      meta_graph = uff.proto.MetaGraph.decodeText(reader);
                  } catch (error) {
                      throw new uff.Error(`File text format is not uff.MetaGraph (${reason(error)}).`);
                  }
                  break;
              }
              default: {
                  throw new uff.Error(`Unsupported UFF format '${context.type}'.`);
              }
          }
          const metadata = await context.metadata('uff-metadata.json');
          return new uff.Model(metadata, meta_graph);
      }
  };
  uff.Model = class {

      /**
       * Top-level model built from a decoded UFF MetaGraph.
       *
       * Any node field whose value is a `ref` present in `meta_graph.referenced_data`
       * is replaced in place by the referenced data before the graphs are built.
       *
       * @param {object} metadata - operator metadata forwarded to each uff.Graph.
       * @param {object} meta_graph - decoded MetaGraph (`version`, `descriptors`,
       *     `referenced_data`, `graphs`); its node fields are mutated.
       */
      constructor(metadata, meta_graph) {
          const version = meta_graph.version;
          this.format = version ? `UFF v${version}` : 'UFF';
          this.imports = meta_graph.descriptors.map((entry) => `${entry.id} v${entry.version}`);
          const lookup = new Map();
          for (const item of meta_graph.referenced_data) {
              lookup.set(item.key, item.value);
          }
          // Swaps a `ref` field value for the data it points at, when that data exists.
          const resolve = (field) => {
              const current = field.value;
              if (current.type === 'ref' && lookup.has(current.ref)) {
                  field.value = lookup.get(current.ref);
              }
          };
          for (const graph of meta_graph.graphs) {
              for (const node of graph.nodes) {
                  for (const field of node.fields) {
                      resolve(field);
                  }
              }
          }
          this.modules = meta_graph.graphs.map((graph) => new uff.Graph(metadata, graph));
      }
  };
  uff.Graph = class {

      /**
       * One UFF graph split into graph inputs, graph outputs and operation nodes.
       *
       * Const nodes that have no inputs, are consumed exactly once and carry
       * `dtype`, `shape` and `values` fields are folded into their value as an
       * initializer and removed from `graph.nodes`. Input nodes become graph
       * inputs; MarkOutput nodes with a single input become graph outputs;
       * everything else becomes a uff.Node.
       *
       * @param {object} metadata - operator metadata forwarded to uff.Node.
       * @param {object} graph - decoded graph (`id`, `nodes`); `graph.nodes` is mutated.
       */
      constructor(metadata, graph) {
          this.name = graph.id;
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          // Collects a node's fields into a key -> value object (later keys win).
          const fieldsOf = (node) => {
              const result = {};
              for (const field of node.fields) {
                  result[field.key] = field.value;
              }
              return result;
          };
          // First pass: one value per identifier, plus how often each identifier is consumed.
          const values = new Map();
          const usage = new Map();
          graph.nodes.forEach((node) => {
              node.inputs.forEach((name) => {
                  usage.set(name, (usage.get(name) || 0) + 1);
                  values.set(name, new uff.Value(name));
              });
              if (!values.has(node.id)) {
                  values.set(node.id, new uff.Value(node.id));
              }
          });
          // uff.Node resolves identifiers through this helper.
          values.map = (name) => values.get(name);
          // Second pass runs back to front so splicing never disturbs indices still to visit.
          let index = graph.nodes.length;
          while (index-- > 0) {
              const node = graph.nodes[index];
              const isSource = node.inputs.length === 0;
              if (isSource && node.operation === 'Const' && usage.get(node.id) === 1) {
                  const { dtype, shape, values: data } = fieldsOf(node);
                  if (dtype && shape && data) {
                      const tensor = new uff.Tensor(dtype.dtype, shape, data);
                      values.set(node.id, new uff.Value(node.id, tensor.type, tensor));
                      graph.nodes.splice(index, 1);
                  }
              }
              if (isSource && node.operation === 'Input') {
                  const { dtype, shape } = fieldsOf(node);
                  const type = dtype && shape ? new uff.TensorType(dtype.dtype, shape) : null;
                  values.set(node.id, new uff.Value(node.id, type, null));
              }
          }
          // Final pass: classify the remaining nodes.
          for (const node of graph.nodes) {
              if (node.operation === 'Input') {
                  this.inputs.push(new uff.Argument(node.id, [values.get(node.id)]));
              } else if (node.operation === 'MarkOutput' && node.inputs.length === 1) {
                  this.outputs.push(new uff.Argument(node.id, [values.get(node.inputs[0])]));
              } else {
                  this.nodes.push(new uff.Node(metadata, node, values));
              }
          }
      }
  };
  uff.Argument = class {

      /**
       * Named slot holding a value (or list of values) with an optional type label.
       *
       * @param {string} name - argument or attribute name.
       * @param {*} value - the bound value(s).
       * @param {string|null} [type=null] - display type such as 'int64' or 'string[]'.
       */
      constructor(name, value, type = null) {
          Object.assign(this, { name, value, type });
      }
  };
  uff.Value = class {

      /**
       * Named value flowing between nodes, optionally typed and/or backed by a
       * constant tensor.
       *
       * @param {string} name - value identifier; any non-string is rejected.
       * @param {uff.TensorType|null} [type=null] - known tensor type, if any.
       * @param {uff.Tensor|null} [initializer=null] - constant data, if any.
       * @throws {uff.Error} when `name` is not a string.
       */
      constructor(name, type = null, initializer = null) {
          if (typeof name === 'string') {
              this.name = name;
              this.type = type;
              this.initializer = initializer;
          } else {
              throw new uff.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
      }
  };
  uff.Node = class {

      /**
       * View of one UFF operation node.
       *
       * Inputs are bound to the operator's schema inputs (`metadata.type(op).inputs`)
       * where available; identifiers left over are exposed positionally. Every
       * field of the node becomes an attribute with a display type.
       *
       * @param {object} metadata - operator metadata; `metadata.type(op)` may return a schema.
       * @param {object} node - decoded node (`id`, `operation`, `inputs`, `fields`).
       * @param {Map} values - value table that also exposes a `map(name)` lookup.
       * @throws {uff.Error} when a field carries an unsupported value kind.
       */
      constructor(metadata, node, values) {
          this.name = node.id;
          this.type = metadata.type(node.operation) || { name: node.operation };
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          if (node.inputs && node.inputs.length > 0) {
              let index = 0;
              if (this.type && this.type.inputs) {
                  // `schema` (not `metadata`) so the constructor parameter is not shadowed.
                  for (const schema of this.type.inputs) {
                      // Bind while identifiers remain; non-optional inputs are bound even when empty.
                      if (index < node.inputs.length || schema.optional !== true) {
                          // A list input takes every remaining identifier.
                          const count = schema.list ? (node.inputs.length - index) : 1;
                          const identifiers = node.inputs.slice(index, index + count);
                          this.inputs.push(new uff.Argument(schema.name, identifiers.map((name) => values.map(name))));
                          index += count;
                      }
                  }
              }
              // Identifiers not claimed by the schema are named by position ('input' for position 0).
              this.inputs.push(...node.inputs.slice(index).map((identifier, i) => {
                  const name = (index + i) === 0 ? 'input' : (index + i).toString();
                  return new uff.Argument(name, [values.map(identifier)]);
              }));
          }
          this.outputs.push(new uff.Argument('output', [values.map(node.id)]));
          for (const field of node.fields) {
              let type = null;
              let value = null;
              // NOTE(review): s_list, b_list and dtype_list are read without `.val`, unlike
              // d_list, i_list and dim_orders_list; confirm against the uff proto schema.
              switch (field.value.type) {
                  case 's': value = field.value.s; type = 'string'; break;
                  case 's_list': value = field.value.s_list; type = 'string[]'; break;
                  case 'd': value = field.value.d; type = 'float64'; break;
                  case 'd_list': value = field.value.d_list.val; type = 'float64[]'; break;
                  case 'b': value = field.value.b; type = 'boolean'; break;
                  case 'b_list': value = field.value.b_list; type = 'boolean[]'; break;
                  case 'i': value = field.value.i; type = 'int64'; break;
                  case 'i_list': value = field.value.i_list.val; type = 'int64[]'; break;
                  case 'blob': value = field.value.blob; break;
                  case 'ref': value = field.value.ref; type = 'ref'; break;
                  case 'dtype': value = new uff.TensorType(field.value.dtype, null).dataType; type = 'uff.DataType'; break;
                  case 'dtype_list': value = field.value.dtype_list.map((item) => new uff.TensorType(item, null).dataType); type = 'uff.DataType[]'; break;
                  case 'dim_orders': value = field.value.dim_orders; break;
                  case 'dim_orders_list': value = field.value.dim_orders_list.val; break;
                  // `value` is still null here, so report the offending kind instead.
                  default: throw new uff.Error(`Unsupported attribute '${field.key}' value '${JSON.stringify(field.value.type)}'.`);
              }
              this.attributes.push(new uff.Argument(field.key, value, type));
          }
      }
  };
  uff.Tensor = class {

      /**
       * Constant tensor built from a node's `dtype`, `shape` and `values` fields.
       *
       * @param {number} dataType - data type passed to uff.TensorType.
       * @param {object} shape - shape field data passed to uff.TensorType.
       * @param {object} values - field data; only the `blob` encoding is accepted.
       * @throws {uff.Error} for any other `values` encoding.
       */
      constructor(dataType, shape, values) {
          this.type = new uff.TensorType(dataType, shape);
          if (values.type !== 'blob') {
              throw new uff.Error(`Unsupported values format '${JSON.stringify(values.type)}'.`);
          }
          const bytes = values.blob;
          const size = bytes.length;
          // A blob longer than 8 bytes framed by "(..." and "...)" is treated as
          // having no usable data. NOTE(review): presumably an exporter placeholder
          // for elided contents; confirm.
          const head = [0x28, 0x2e, 0x2e, 0x2e];
          const tail = [0x2e, 0x2e, 0x2e, 0x29];
          const isPlaceholder = size > 8 &&
              head.every((byte, i) => bytes[i] === byte) &&
              tail.every((byte, i) => bytes[size - tail.length + i] === byte);
          this.values = isPlaceholder ? null : bytes;
      }
  };
  uff.TensorType = class {

      /**
       * Element data type plus an optional shape.
       *
       * @param {number} dataType - a `uff.proto.DataType` value (or 7, shown as '?').
       * @param {object|null} shape - shape field data, or null/undefined when unknown.
       * @throws {uff.Error} for data types not handled below.
       */
      constructor(dataType, shape) {
          switch (dataType) {
              case uff.proto.DataType.DT_INT8: this.dataType = 'int8'; break;
              case uff.proto.DataType.DT_INT16: this.dataType = 'int16'; break;
              case uff.proto.DataType.DT_INT32: this.dataType = 'int32'; break;
              case uff.proto.DataType.DT_INT64: this.dataType = 'int64'; break;
              case uff.proto.DataType.DT_FLOAT16: this.dataType = 'float16'; break;
              case uff.proto.DataType.DT_FLOAT32: this.dataType = 'float32'; break;
              // NOTE(review): raw value 7 has no named DataType constant here; its meaning is unknown.
              case 7: this.dataType = '?'; break;
              default: throw new uff.Error(`Unsupported data type '${JSON.stringify(dataType)}'.`);
          }
          this.shape = shape ? new uff.TensorShape(shape) : null;
      }

      /**
       * @returns {string} the data type followed by the shape (e.g. `float32[1,3]`),
       *     or just the data type when no shape is known.
       */
      toString() {
          // `shape` is null for shapeless types (e.g. attribute data types), which
          // previously made this throw a TypeError.
          return this.dataType + (this.shape ? this.shape.toString() : '');
      }
  };
  uff.TensorShape = class {

      /**
       * Tensor dimensions read from an `i_list` shape field.
       *
       * @param {object} shape - field data; must have `type === 'i_list'`.
       * @throws {uff.Error} for any other shape encoding.
       */
      constructor(shape) {
          if (shape.type !== 'i_list') {
              throw new uff.Error(`Unsupported shape format '${JSON.stringify(shape.type)}'.`);
          }
          // Native BigInt has no toNumber() method; Number() is the correct conversion
          // for int64 dimensions.
          this.dimensions = shape.i_list.val.map((dim) => typeof dim === 'bigint' ? Number(dim) : dim);
      }

      /**
       * @returns {string} `[d0,d1,...]`, or an empty string when there are no dimensions.
       */
      toString() {
          if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
              return `[${this.dimensions.join(',')}]`;
          }
          return '';
      }
  };
  uff.Error = class extends Error {

      // Fixed name reported for every UFF loading failure.
      name = 'Error loading UFF model.';

      /**
       * @param {string} message - description of what went wrong.
       */
      constructor(message) {
          super(message);
      }
  };

  // The factory is the only export in this file; the other uff.* classes are reached through it.
  export const ModelFactory = uff.ModelFactory;