uff.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  // Namespace object for this module: every class below is attached to it as uff.<Name>,
  // and ModelFactory.open() also stores the loaded protobuf schema on it as uff.proto.
  // NOTE(review): declared with `var`; if the loader can evaluate this file as a classic script
  // more than once, a `const`/`let` declaration would throw on re-declaration — confirm before changing.
  1. var uff = {};
  2. var protobuf = require('./protobuf');
  uff.ModelFactory = class {

      /**
       * Reports which UFF encoding, if any, the candidate file uses.
       * @param {object} context - host context; reads `identifier` and calls `tags('pb' | 'pbtxt')`.
       * @returns {string|undefined} 'uff.pb' (binary), 'uff.pbtxt' (text), or undefined.
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          if (extension === 'uff' || extension === 'pb') {
              // NOTE(review): `tags` presumably maps top-level protobuf field numbers to wire types
              // (0 = varint, 2 = length-delimited); this checks for the expected uff.MetaGraph layout.
              const tags = context.tags('pb');
              if (tags.size > 0 &&
                  tags.has(1) && tags.get(1) === 0 &&
                  tags.has(2) && tags.get(2) === 0 &&
                  tags.has(3) && tags.get(3) === 2 &&
                  tags.has(4) && tags.get(4) === 2 &&
                  (!tags.has(5) || tags.get(5) === 2)) {
                  return 'uff.pb';
              }
          }
          if (extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt')) {
              const tags = context.tags('pbtxt');
              if (tags.has('version') && tags.has('descriptors') && tags.has('graphs')) {
                  return 'uff.pbtxt';
              }
          }
          return undefined;
      }

      /**
       * Decodes the model and wraps it in a uff.Model.
       * Loads the './uff-proto' schema first and stores it on `uff.proto` (uff.TensorType reads it).
       * @param {object} context - host context providing require(), stream and metadata().
       * @param {string} target - value previously returned by match().
       * @returns {Promise<uff.Model>} the decoded model.
       * @throws {uff.Error} when decoding fails or `target` is not a UFF format.
       */
      async open(context, target) {
          await context.require('./uff-proto');
          uff.proto = protobuf.get('uff').uff;
          // Decoders may throw Error objects or arbitrary values; reduce either to text and drop a
          // trailing period so the wrapped message does not end in '.).'. String() also copes with
          // null/undefined, where toString() would itself throw.
          const reason = (error) => {
              const message = String(error && error.message ? error.message : error);
              return message.replace(/\.$/, '');
          };
          let meta_graph = null;
          switch (target) {
              case 'uff.pb': {
                  try {
                      const reader = protobuf.BinaryReader.open(context.stream);
                      meta_graph = uff.proto.MetaGraph.decode(reader);
                  } catch (error) {
                      throw new uff.Error('File format is not uff.MetaGraph (' + reason(error) + ').');
                  }
                  break;
              }
              case 'uff.pbtxt': {
                  try {
                      const reader = protobuf.TextReader.open(context.stream);
                      meta_graph = uff.proto.MetaGraph.decodeText(reader);
                  } catch (error) {
                      throw new uff.Error('File text format is not uff.MetaGraph (' + reason(error) + ').');
                  }
                  break;
              }
              default: {
                  throw new uff.Error("Unsupported UFF format '" + target + "'.");
              }
          }
          const metadata = await context.metadata('uff-metadata.json');
          return new uff.Model(metadata, meta_graph);
      }
  };
  uff.Model = class {

      /**
       * Top-level wrapper around a decoded uff.MetaGraph.
       * Before the graphs are built, every node field whose value is a 'ref' naming an entry of
       * `referenced_data` is overwritten in place with that entry's value.
       * @param {object} metadata - operator metadata, handed on to uff.Graph.
       * @param {object} meta_graph - decoded uff.MetaGraph (its node fields are mutated).
       */
      constructor(metadata, meta_graph) {
          const { version, descriptors, referenced_data, graphs } = meta_graph;
          this._version = version;
          this._imports = descriptors.map((d) => d.id + ' v' + d.version.toString());
          const referenced = new Map();
          for (const entry of referenced_data) {
              referenced.set(entry.key, entry.value);
          }
          for (const graph of graphs) {
              for (const node of graph.nodes) {
                  for (const field of node.fields) {
                      const current = field.value;
                      if (current.type === 'ref' && referenced.has(current.ref)) {
                          field.value = referenced.get(current.ref);
                      }
                  }
              }
          }
          this._graphs = graphs.map((graph) => new uff.Graph(metadata, graph));
      }

      /** 'UFF', suffixed with ' v<version>' when the version is truthy. */
      get format() {
          const version = this._version;
          return version ? 'UFF v' + version.toString() : 'UFF';
      }

      get imports() {
          return this._imports;
      }

      get graphs() {
          return this._graphs;
      }
  };
  uff.Graph = class {

      /**
       * Builds a displayable graph from a decoded uff.Graph message.
       * - 'Input' nodes become graph inputs (typed when they carry dtype and shape fields).
       * - 'MarkOutput' nodes with exactly one input become graph outputs.
       * - Input-less 'Const' nodes referenced exactly once, with dtype/shape/values fields, are folded
       *   into a uff.Value with a uff.Tensor initializer and spliced out of `graph.nodes` (mutated).
       * - Every remaining node becomes a uff.Node.
       * @param {object} metadata - operator metadata, handed on to uff.Node.
       * @param {object} graph - decoded graph exposing `id` and `nodes`.
       */
      constructor(metadata, graph) {
          this._name = graph.id;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          const nodes = graph.nodes;

          // Flattens a node's key/value field list into a lookup object.
          const fieldsOf = (node) => {
              const lookup = {};
              node.fields.forEach((field) => {
                  lookup[field.key] = field.value;
              });
              return lookup;
          };

          // One uff.Value per node id and per referenced input id; count how often each id is consumed.
          const values = new Map();
          const uses = new Map();
          for (const node of nodes) {
              for (const name of node.inputs) {
                  uses.set(name, (uses.get(name) || 0) + 1);
                  values.set(name, new uff.Value(name));
              }
              if (!values.has(node.id)) {
                  values.set(node.id, new uff.Value(node.id));
              }
          }

          // Walk backwards so that splicing out a folded constant never shifts an unvisited index.
          let index = nodes.length;
          while (index-- > 0) {
              const node = nodes[index];
              if (node.inputs.length !== 0) {
                  continue;
              }
              if (node.operation === 'Const' && uses.get(node.id) === 1) {
                  const fields = fieldsOf(node);
                  if (fields.dtype && fields.shape && fields.values) {
                      const tensor = new uff.Tensor(fields.dtype.dtype, fields.shape, fields.values);
                      values.set(node.id, new uff.Value(node.id, tensor.type, tensor));
                      nodes.splice(index, 1);
                  }
              } else if (node.operation === 'Input') {
                  const fields = fieldsOf(node);
                  let type = null;
                  if (fields.dtype && fields.shape) {
                      type = new uff.TensorType(fields.dtype.dtype, fields.shape);
                  }
                  values.set(node.id, new uff.Value(node.id, type, null));
              }
          }

          for (const node of nodes) {
              if (node.operation === 'Input') {
                  this._inputs.push(new uff.Argument(node.id, [ values.get(node.id) ]));
              } else if (node.operation === 'MarkOutput' && node.inputs.length === 1) {
                  this._outputs.push(new uff.Argument(node.id, [ values.get(node.inputs[0]) ]));
              } else {
                  this._nodes.push(new uff.Node(metadata, node, values));
              }
          }
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  uff.Argument = class {

      /**
       * A named slot (graph or node input/output) bound to a list of values.
       * @param {string} name - slot name.
       * @param {Array} value - values bound to the slot.
       */
      constructor(name, value) {
          Object.assign(this, { _name: name, _value: value });
      }

      get name() { return this._name; }

      get value() { return this._value; }
  };
  uff.Value = class {

      /**
       * A named value flowing between nodes.
       * @param {string} name - value identifier; any non-string is rejected.
       * @param {object} [type] - tensor type; a falsy value is stored as null.
       * @param {object} [initializer] - constant tensor; a falsy value is stored as null.
       * @throws {uff.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type ? type : null;
              this._initializer = initializer ? initializer : null;
          } else {
              throw new uff.Error("Invalid value identifier '" + JSON.stringify(name) + "'.");
          }
      }

      get name() { return this._name; }

      get type() { return this._type; }

      get initializer() { return this._initializer; }
  };
  uff.Node = class {

      /**
       * A single UFF operation.
       * Inputs are matched positionally against the operator schema returned by
       * `metadata.type()`: a schema input flagged `list` takes every remaining input, an `optional`
       * one is listed only when an input is left for it, and a required one is always listed (with
       * an empty value list if the node ran out of inputs). Inputs beyond the schema are named
       * 'input' at position 0, otherwise by their position.
       * @param {object} metadata - exposes type(operation) and attribute(operation, key).
       * @param {object} node - decoded node (id, operation, inputs, fields).
       * @param {Map<string, uff.Value>} args - values keyed by id.
       */
      constructor(metadata, node, args) {
          this._name = node.id;
          this._type = metadata.type(node.operation) || { name: node.operation };
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          if (node.inputs && node.inputs.length > 0) {
              let inputIndex = 0;
              // `this._type` is always an object here, so only its `inputs` schema needs checking.
              if (this._type.inputs) {
                  for (const inputSchema of this._type.inputs) {
                      if (inputIndex < node.inputs.length || inputSchema.optional !== true) {
                          // Clamp so a list input seen after the inputs ran out never moves the index backwards.
                          const inputCount = inputSchema.list ? Math.max(0, node.inputs.length - inputIndex) : 1;
                          const inputArguments = node.inputs.slice(inputIndex, inputIndex + inputCount).map((id) => args.get(id));
                          inputIndex += inputCount;
                          this._inputs.push(new uff.Argument(inputSchema.name, inputArguments));
                      }
                  }
              }
              this._inputs.push(...node.inputs.slice(inputIndex).map((id, index) => {
                  const position = inputIndex + index;
                  const inputName = position === 0 ? 'input' : position.toString();
                  return new uff.Argument(inputName, [ args.get(id) ]);
              }));
          }
          this._outputs.push(new uff.Argument('output', [ args.get(node.id) ]));
          for (const field of node.fields) {
              this._attributes.push(new uff.Attribute(metadata.attribute(node.operation, field.key), field.key, field.value));
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  uff.Attribute = class {

      /**
       * A node field rendered as an attribute.
       * @param {object} metadata - attribute metadata (not used here).
       * @param {string} name - field key.
       * @param {object} value - field value; `value.type` names which member is populated.
       * @throws {uff.Error} for an unrecognised `value.type`.
       */
      constructor(metadata, name, value) {
          this._name = name;
          // The *_list members appear to be protobuf oneof members (selected via `value.type`);
          // oneof members cannot be `repeated`, so lists arrive as wrapper messages holding their
          // elements in `val` (as d_list/i_list already assumed). Plain arrays pass through as-is.
          const items = (list) => Array.isArray(list) ? list : list.val;
          switch (value.type) {
              case 's': this._value = value.s; this._type = 'string'; break;
              case 's_list': this._value = items(value.s_list); this._type = 'string[]'; break;
              case 'd': this._value = value.d; this._type = 'float64'; break;
              case 'd_list': this._value = items(value.d_list); this._type = 'float64[]'; break;
              case 'b': this._value = value.b; this._type = 'boolean'; break;
              case 'b_list': this._value = items(value.b_list); this._type = 'boolean[]'; break;
              case 'i': this._value = value.i; this._type = 'int64'; break;
              case 'i_list': this._value = items(value.i_list); this._type = 'int64[]'; break;
              case 'blob': this._value = value.blob; break;
              case 'ref': this._value = value.ref; this._type = 'ref'; break;
              case 'dtype': this._value = new uff.TensorType(value.dtype, null).dataType; this._type = 'uff.DataType'; break;
              case 'dtype_list': this._value = items(value.dtype_list).map((type) => new uff.TensorType(type, null).dataType); this._type = 'uff.DataType[]'; break;
              case 'dim_orders': this._value = value.dim_orders; break;
              case 'dim_orders_list': this._value = items(value.dim_orders_list); break;
              default: throw new uff.Error("Unsupported attribute '" + name + "' value '" + JSON.stringify(value) + "'.");
          }
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return true;
      }
  };
  uff.Tensor = class {

      /**
       * Constant tensor built from a Const node's dtype/shape/values fields.
       * @param {number} dataType - data type enum value, handed to uff.TensorType.
       * @param {object} shape - shape field value, handed to uff.TensorType.
       * @param {object} values - values field value; only the 'blob' member is supported.
       * @throws {uff.Error} when `values` is not a blob.
       */
      constructor(dataType, shape, values) {
          this._type = new uff.TensorType(dataType, shape);
          const kind = values.type;
          if (kind !== 'blob') {
              throw new uff.Error("Unsupported values format '" + JSON.stringify(kind) + "'.");
          }
          this._data = values.blob;
          // A blob longer than 8 bytes that starts with '(...' and ends with '...)' is treated as a
          // textual placeholder rather than tensor data (presumably elided weights — confirm).
          const data = this._data;
          const last = data.length - 1;
          const placeholder = data.length > 8 &&
              [ 0x28, 0x2e, 0x2e, 0x2e ].every((byte, i) => data[i] === byte) &&
              [ 0x29, 0x2e, 0x2e, 0x2e ].every((byte, i) => data[last - i] === byte);
          if (placeholder) {
              this._data = null;
          }
      }

      get type() {
          return this._type;
      }

      get values() {
          return this._data;
      }
  };
  uff.TensorType = class {

      /**
       * Element data type plus optional shape.
       * @param {number} dataType - enum value compared against `uff.proto.DataType`.
       * @param {object|null} shape - shape field value, or a falsy value when there is none.
       * @throws {uff.Error} for an unrecognised data type.
       */
      constructor(dataType, shape) {
          const DataType = uff.proto.DataType;
          switch (dataType) {
              case DataType.DT_INT8: this._dataType = 'int8'; break;
              case DataType.DT_INT16: this._dataType = 'int16'; break;
              case DataType.DT_INT32: this._dataType = 'int32'; break;
              case DataType.DT_INT64: this._dataType = 'int64'; break;
              case DataType.DT_FLOAT16: this._dataType = 'float16'; break;
              case DataType.DT_FLOAT32: this._dataType = 'float32'; break;
              // NOTE(review): raw value 7 is not matched by any named member above; shown as '?' — confirm its meaning.
              case 7: this._dataType = '?'; break;
              default: throw new uff.Error("Unsupported data type '" + JSON.stringify(dataType) + "'.");
          }
          this._shape = shape ? new uff.TensorShape(shape) : null;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      /** Data type followed by the shape text; the shape part is omitted when there is no shape. */
      toString() {
          return this.dataType + (this._shape ? this._shape.toString() : '');
      }
  };
  uff.TensorShape = class {

      /**
       * Tensor dimensions read from a shape field value.
       * @param {object} shape - field value; only the 'i_list' member is supported.
       * @throws {uff.Error} when `shape` is not an integer list.
       */
      constructor(shape) {
          const kind = shape.type;
          if (kind === 'i_list') {
              this._dimensions = shape.i_list.val;
          } else {
              throw new uff.Error("Unsupported shape format '" + JSON.stringify(kind) + "'.");
          }
      }

      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or an empty string when there are no dimensions. */
      toString() {
          const dims = this._dimensions;
          const empty = !dims || dims.length === 0;
          return empty ? '' : '[' + dims.join(',') + ']';
      }
  };
  uff.Error = class extends Error {

      /**
       * Error raised for any failure while loading a UFF model.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          // `name` is a sentence rather than a class name; presumably the host displays it as a title.
          Object.assign(this, { name: 'Error loading UFF model.' });
      }
  };
  // Publish the factory only when a CommonJS `module.exports` object is present; elsewhere this is a no-op.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = uff.ModelFactory;
      }
  }