uff.js 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. // Experimental
  2. var uff = uff || {};
  3. var protobuf = protobuf || require('./protobuf');
  /**
   * Factory that recognizes UFF model files and loads them into a uff.Model.
   */
  uff.ModelFactory = class {

      /**
       * Decides whether the file described by `context` looks like a UFF model.
       *
       * Binary candidates ('.uff' / '.pb') are accepted when the tag map returned
       * by `context.tags('pb')` maps fields 1 and 2 to 0, fields 3 and 4 to 2 and,
       * if present, field 5 to 2 (presumably protobuf wire types - TODO confirm).
       * Text candidates ('.pbtxt' / '.uff.txt') are accepted when the tag map
       * contains 'version', 'descriptors' and 'graphs'.
       *
       * @param context Object exposing `identifier` and `tags(kind)`.
       * @returns {string|undefined} 'uff.pb', 'uff.pbtxt', or undefined when nothing matches.
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          if (extension === 'uff' || extension === 'pb') {
              const tags = context.tags('pb');
              if (tags.size > 0 &&
                  tags.has(1) && tags.get(1) === 0 &&
                  tags.has(2) && tags.get(2) === 0 &&
                  tags.has(3) && tags.get(3) === 2 &&
                  tags.has(4) && tags.get(4) === 2 &&
                  (!tags.has(5) || tags.get(5) === 2)) {
                  return 'uff.pb';
              }
          }
          if (extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt')) {
              const tags = context.tags('pbtxt');
              if (tags.has('version') && tags.has('descriptors') && tags.has('graphs')) {
                  return 'uff.pbtxt';
              }
          }
          return undefined;
      }

      /**
       * Loads the UFF schema, decodes the file as a uff.MetaGraph and wraps it
       * in a uff.Model together with the 'uff-metadata.json' metadata.
       *
       * @param context Object exposing `require`, `stream` and `metadata`.
       * @param {string} match Value previously returned by `match`.
       * @returns {Promise} Resolves to a uff.Model.
       * @throws {uff.Error} (as a rejection) when decoding fails or `match` is
       *     not a supported format.
       */
      open(context, match) {
          return context.require('./uff-proto').then(() => {
              uff.proto = protobuf.get('uff').uff;
              // Turns whatever was thrown into text without a trailing period so the
              // wrapped message never reads "undefined" or ends in "..).". Both
              // decode branches share it so their error messages stay consistent.
              const describe = (error) => {
                  const message = error && error.message ? error.message : String(error);
                  return message.replace(/\.$/, '');
              };
              let meta_graph = null;
              switch (match) {
                  case 'uff.pb': {
                      try {
                          const stream = context.stream;
                          const reader = protobuf.BinaryReader.open(stream);
                          meta_graph = uff.proto.MetaGraph.decode(reader);
                      }
                      catch (error) {
                          throw new uff.Error('File format is not uff.MetaGraph (' + describe(error) + ').');
                      }
                      break;
                  }
                  case 'uff.pbtxt': {
                      try {
                          const stream = context.stream;
                          const reader = protobuf.TextReader.open(stream);
                          meta_graph = uff.proto.MetaGraph.decodeText(reader);
                      }
                      catch (error) {
                          throw new uff.Error('File text format is not uff.MetaGraph (' + describe(error) + ').');
                      }
                      break;
                  }
                  default: {
                      throw new uff.Error("Unsupported UFF format '" + match + "'.");
                  }
              }
              return context.metadata('uff-metadata.json').then((metadata) => {
                  return new uff.Model(metadata, meta_graph);
              });
          });
      }
  };
  /**
   * Top-level model: exposes the format string, the imported descriptors and
   * the graphs of a decoded meta graph.
   */
  uff.Model = class {

      /**
       * @param metadata Operator metadata handed on to every uff.Graph.
       * @param meta_graph Decoded meta graph; fields of type 'ref' that have a
       *     matching `referenced_data` entry are replaced in place by that entry's value.
       */
      constructor(metadata, meta_graph) {
          this._version = meta_graph.version;

          const imports = [];
          for (const descriptor of meta_graph.descriptors) {
              imports.push(descriptor.id + ' v' + descriptor.version.toString());
          }
          this._imports = imports;

          // Index referenced data by key; a later duplicate key overrides an earlier one.
          const lookup = new Map();
          for (const entry of meta_graph.referenced_data) {
              lookup.set(entry.key, entry.value);
          }

          // Substitute resolvable 'ref' fields before the graphs are built.
          for (const graph of meta_graph.graphs) {
              for (const node of graph.nodes) {
                  for (const field of node.fields) {
                      if (field.value.type !== 'ref') {
                          continue;
                      }
                      if (!lookup.has(field.value.ref)) {
                          continue;
                      }
                      field.value = lookup.get(field.value.ref);
                  }
              }
          }

          const graphs = [];
          for (const graph of meta_graph.graphs) {
              graphs.push(new uff.Graph(metadata, graph));
          }
          this._graphs = graphs;
      }

      /** 'UFF' followed by ' v<version>' when the version is truthy. */
      get format() {
          if (this._version) {
              return 'UFF' + ' v' + this._version.toString();
          }
          return 'UFF' + '';
      }

      /** One '<id> v<version>' string per descriptor. */
      get imports() {
          return this._imports;
      }

      /** The uff.Graph instances of this model. */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * A single graph: splits the raw node list into graph inputs, graph outputs
   * and regular nodes, sharing one uff.Argument per identifier.
   */
  uff.Graph = class {

      /**
       * @param metadata Operator metadata handed on to every uff.Node.
       * @param graph Decoded graph; 'Const' nodes that are folded into
       *     initializers are removed from `graph.nodes` in place.
       */
      constructor(metadata, graph) {
          this._name = graph.id;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];

          // Flattens a node's key/value field list into a plain lookup object.
          const fieldsOf = (node) => {
              const lookup = {};
              for (const field of node.fields) {
                  lookup[field.key] = field.value;
              }
              return lookup;
          };

          const args = new Map();
          const counts = new Map();

          // Pass 1: count how often each identifier is consumed and make sure
          // every consumed or produced identifier has an argument.
          for (const node of graph.nodes) {
              for (const input of node.inputs) {
                  counts.set(input, (counts.get(input) || 0) + 1);
                  args.set(input, new uff.Argument(input));
              }
              if (!args.has(node.id)) {
                  args.set(node.id, new uff.Argument(node.id));
              }
          }

          // Pass 2: walk backwards so splicing does not disturb the indices
          // that are still to be visited.
          let index = graph.nodes.length;
          while (index-- > 0) {
              const node = graph.nodes[index];
              switch (node.operation) {
                  case 'Const': {
                      // Only constants without inputs that are consumed exactly once are folded.
                      if (node.inputs.length !== 0 || counts.get(node.id) !== 1) {
                          break;
                      }
                      const fields = fieldsOf(node);
                      if (!(fields.dtype && fields.shape && fields.values)) {
                          break;
                      }
                      const tensor = new uff.Tensor(fields.dtype.dtype, fields.shape, fields.values);
                      args.set(node.id, new uff.Argument(node.id, tensor.type, tensor));
                      graph.nodes.splice(index, 1);
                      break;
                  }
                  case 'Input': {
                      if (node.inputs.length !== 0) {
                          break;
                      }
                      const fields = fieldsOf(node);
                      const type = fields.dtype && fields.shape ? new uff.TensorType(fields.dtype.dtype, fields.shape) : null;
                      args.set(node.id, new uff.Argument(node.id, type, null));
                      break;
                  }
                  default: {
                      break;
                  }
              }
          }

          // Pass 3: classify the remaining nodes.
          for (const node of graph.nodes) {
              if (node.operation === 'Input') {
                  this._inputs.push(new uff.Parameter(node.id, [ args.get(node.id) ]));
              }
              else if (node.operation === 'MarkOutput' && node.inputs.length === 1) {
                  this._outputs.push(new uff.Parameter(node.id, [ args.get(node.inputs[0]) ]));
              }
              else {
                  this._nodes.push(new uff.Node(metadata, node, args));
              }
          }
      }

      /** Identifier of the graph. */
      get name() {
          return this._name;
      }

      /** Parameters created from 'Input' nodes. */
      get inputs() {
          return this._inputs;
      }

      /** Parameters created from single-input 'MarkOutput' nodes. */
      get outputs() {
          return this._outputs;
      }

      /** All remaining nodes as uff.Node instances. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named group of arguments attached to a graph or node as an input or output.
   */
  uff.Parameter = class {

      /**
       * @param name Display name of the parameter.
       * @param args Arguments grouped under this parameter.
       */
      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      /** Name passed to the constructor. */
      get name() {
          return this._name;
      }

      /** Always true. */
      get visible() {
          return true;
      }

      /** Arguments passed to the constructor. */
      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A value flowing through the graph, identified by a string name and
   * optionally carrying a type and an initializer.
   */
  uff.Argument = class {

      /**
       * @param {string} name Identifier of the argument.
       * @param [type] Stored as null when falsy.
       * @param [initializer] Stored as null when falsy.
       * @throws {uff.Error} When `name` is not a string.
       */
      constructor(name, type, initializer) {
          const isValidName = typeof name === 'string';
          if (!isValidName) {
              throw new uff.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      /** Identifier of the argument. */
      get name() {
          return this._name;
      }

      /** Type given at construction, or null. */
      get type() {
          return this._type;
      }

      /** Initializer given at construction, or null. */
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A graph node: resolves its operator schema, groups its inputs into
   * parameters and wraps its fields as attributes.
   */
  uff.Node = class {

      /**
       * @param metadata Provides `type(operation)` and `attribute(operation, key)`.
       * @param node Decoded node with `id`, `operation`, `inputs` and `fields`.
       * @param {Map} args Lookup from identifier to uff.Argument.
       */
      constructor(metadata, node, args) {
          this._name = node.id;
          // Fall back to a bare schema when the operator is unknown to the metadata.
          this._type = metadata.type(node.operation) || { name: node.operation };
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];

          const resolve = (id) => args.get(id);

          if (node.inputs && node.inputs.length > 0) {
              let cursor = 0;
              const declared = this._type && this._type.inputs;
              if (declared) {
                  for (const schema of declared) {
                      // Optional schema inputs are skipped once the actual inputs are used up.
                      const exhausted = !(cursor < node.inputs.length);
                      if (exhausted && schema.optional === true) {
                          continue;
                      }
                      // A list input takes everything that is left, any other input takes one.
                      const taken = schema.list ? (node.inputs.length - cursor) : 1;
                      const group = node.inputs.slice(cursor, cursor + taken).map(resolve);
                      cursor += taken;
                      this._inputs.push(new uff.Parameter(schema.name, group));
                  }
              }
              // Inputs not covered by the schema are named by position ('input' for position 0).
              const leftover = node.inputs.slice(cursor);
              for (let offset = 0; offset < leftover.length; offset++) {
                  const position = cursor + offset;
                  const label = (position == 0) ? 'input' : position.toString();
                  this._inputs.push(new uff.Parameter(label, [ resolve(leftover[offset]) ]));
              }
          }

          // Every node has a single output named after the node itself.
          const result = args.get(node.id);
          this._outputs.push(new uff.Parameter('output', [ result ]));

          for (const field of node.fields) {
              const schema = metadata.attribute(node.operation, field.key);
              this._attributes.push(new uff.Attribute(schema, field.key, field.value));
          }
      }

      /** Identifier of the node. */
      get name() {
          return this._name;
      }

      /** Operator schema, or `{ name: operation }` when the metadata has none. */
      get type() {
          return this._type;
      }

      /** Input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** Output parameters (a single 'output'). */
      get outputs() {
          return this._outputs;
      }

      /** Attributes built from the node's fields. */
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * A node attribute: unwraps a tagged field value into a plain value and,
   * for most kinds, a type label.
   */
  uff.Attribute = class {

      /**
       * @param metadata Attribute schema (not read by this constructor).
       * @param name Attribute name.
       * @param value Tagged value; `value.type` selects which member is read.
       * @throws {uff.Error} When `value.type` is not one of the supported kinds.
       */
      constructor(metadata, name, value) {
          this._name = name;

          // Each decoder returns [ value ] or [ value, type label ]; kinds
          // without a label leave the type unset.
          const decoders = new Map([
              [ 's', (v) => [ v.s, 'string' ] ],
              [ 's_list', (v) => [ v.s_list, 'string[]' ] ],
              [ 'd', (v) => [ v.d, 'float64' ] ],
              [ 'd_list', (v) => [ v.d_list.val, 'float64[]' ] ],
              [ 'b', (v) => [ v.b, 'boolean' ] ],
              [ 'b_list', (v) => [ v.b_list, 'boolean[]' ] ],
              [ 'i', (v) => [ v.i, 'int64' ] ],
              [ 'i_list', (v) => [ v.i_list.val, 'int64[]' ] ],
              [ 'blob', (v) => [ v.blob ] ],
              [ 'ref', (v) => [ v.ref, 'ref' ] ],
              [ 'dtype', (v) => [ new uff.TensorType(v.dtype, null).dataType, 'uff.DataType' ] ],
              [ 'dtype_list', (v) => [ v.dtype_list.map((type) => new uff.TensorType(type, null).dataType), 'uff.DataType[]' ] ],
              [ 'dim_orders', (v) => [ v.dim_orders ] ],
              [ 'dim_orders_list', (v) => [ v.dim_orders_list.val ] ]
          ]);

          const decode = decoders.get(value.type);
          if (!decode) {
              throw new uff.Error("Unsupported attribute '" + name + "' value '" + JSON.stringify(value) + "'.");
          }
          const decoded = decode(value);
          this._value = decoded[0];
          if (decoded.length > 1) {
              this._type = decoded[1];
          }
      }

      /** Type label, or undefined for kinds that carry none. */
      get type() {
          return this._type;
      }

      /** Attribute name. */
      get name() {
          return this._name;
      }

      /** Unwrapped attribute value. */
      get value() {
          return this._value;
      }

      /** Always true. */
      get visible() {
          return true;
      }
  };
  /**
   * A constant tensor backed by a raw byte blob, decoded on demand into
   * nested arrays that follow the tensor shape.
   */
  uff.Tensor = class {

      /**
       * @param dataType Numeric data type, forwarded to uff.TensorType.
       * @param shape Tagged shape value, forwarded to uff.TensorType.
       * @param values Tagged values; only `type === 'blob'` is supported.
       * @throws {uff.Error} When `values.type` is not 'blob'.
       */
      constructor(dataType, shape, values) {
          this._type = new uff.TensorType(dataType, shape);
          switch (values.type) {
              case 'blob': this._data = values.blob; break;
              default: throw new uff.Error("Unsupported values format '" + JSON.stringify(values.type) + "'.");
          }
      }

      /** Always 'Const'. */
      get kind() {
          return 'Const';
      }

      /** The uff.TensorType of this tensor. */
      get type() {
          return this._type;
      }

      /** Null when the data can be decoded, otherwise a message explaining why not. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded data, or null when `state` reports a problem. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** JSON text of the decoded data, truncated with '...' once the element limit is passed. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /**
       * Builds a fresh decoding context. `state` is set (and the reader fields
       * are left out) when the data is missing, is a placeholder, or has an
       * unknown data type.
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          // Data that starts with the bytes "(..." and ends with "...)" is
          // treated as empty (presumably a marker for elided weights - TODO confirm).
          if (this._data.length > 8 &&
              this._data[0] === 0x28 && this._data[1] === 0x2e && this._data[2] === 0x2e && this._data[3] === 0x2e &&
              this._data[this._data.length - 1] === 0x29 && this._data[this._data.length - 2] === 0x2e && this._data[this._data.length - 3] === 0x2e && this._data[this._data.length - 4] === 0x2e) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          if (this._type.dataType === '?') {
              context.state = 'Tensor data type is unknown.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      /**
       * Recursively decodes one dimension of the tensor.
       *
       * @param context Context from `_context()` with `limit` set by the caller.
       * @param {number} dimension Index of the dimension being decoded.
       * @returns Nested arrays, or the bare element for a zero-dimensional shape.
       */
      _decode(context, dimension) {
          // A zero-dimensional shape is decoded as a single element and unwrapped at the end.
          const shape = (context.shape.length === 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int16':
                          // Little-endian like every other multi-byte read here;
                          // without the flag DataView reads big-endian.
                          results.push(context.data.getInt16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'int64':
                          // NOTE(review): getInt64 is not a standard DataView method;
                          // presumably supplied by the host application - confirm.
                          results.push(context.data.getInt64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'float16':
                          // NOTE(review): getFloat16 is not available in every runtime;
                          // presumably supplied by the host application - confirm.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Element data type plus optional shape of a tensor.
   */
  uff.TensorType = class {

      /**
       * @param dataType Numeric value compared against `uff.proto.DataType`
       *     members; the literal 7 is mapped to '?'.
       * @param [shape] Tagged shape value wrapped in a uff.TensorShape; a falsy
       *     value leaves the shape null.
       * @throws {uff.Error} When `dataType` is not one of the handled values.
       */
      constructor(dataType, shape) {
          switch (dataType) {
              case uff.proto.DataType.DT_INT8: this._dataType = 'int8'; break;
              case uff.proto.DataType.DT_INT16: this._dataType = 'int16'; break;
              case uff.proto.DataType.DT_INT32: this._dataType = 'int32'; break;
              case uff.proto.DataType.DT_INT64: this._dataType = 'int64'; break;
              case uff.proto.DataType.DT_FLOAT16: this._dataType = 'float16'; break;
              case uff.proto.DataType.DT_FLOAT32: this._dataType = 'float32'; break;
              case 7: this._dataType = '?'; break;
              default: throw new uff.Error("Unsupported data type '" + JSON.stringify(dataType) + "'.");
          }
          this._shape = shape ? new uff.TensorShape(shape) : null;
      }

      /** Data type name such as 'float32', or '?' when unknown. */
      get dataType() {
          return this._dataType;
      }

      /** The uff.TensorShape, or null when no shape was supplied. */
      get shape() {
          return this._shape;
      }

      /**
       * Data type name followed by the shape text. The shape is optional, so
       * a type without one prints the data type alone instead of throwing.
       */
      toString() {
          return this.dataType + (this._shape ? this._shape.toString() : '');
      }
  };
  /**
   * Dimensions of a tensor, taken from a tagged integer-list value.
   */
  uff.TensorShape = class {

      /**
       * @param shape Tagged value whose `type` must be 'i_list'; the
       *     dimensions are read from `shape.i_list.val`.
       * @throws {uff.Error} When `shape.type` is anything else.
       */
      constructor(shape) {
          const isIntegerList = shape.type === 'i_list';
          if (!isIntegerList) {
              throw new uff.Error("Unsupported shape format '" + JSON.stringify(shape.type) + "'.");
          }
          this._dimensions = shape.i_list.val;
      }

      /** The dimension list. */
      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or an empty string when there are no dimensions. */
      toString() {
          const dims = this._dimensions;
          const isEmpty = !dims || dims.length == 0;
          return isEmpty ? '' : '[' + dims.join(',') + ']';
      }
  };
  /**
   * Error type raised by the UFF loader; `name` is set to a fixed
   * 'Error loading UFF model.' label.
   */
  uff.Error = class extends Error {

      /**
       * @param {string} message Description of what went wrong.
       */
      constructor(message) {
          super(message);
          const label = 'Error loading UFF model.';
          this.name = label;
      }
  };
  // Publish the factory only when a CommonJS-style `module.exports` object exists.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      const factory = uff.ModelFactory;
      module.exports.ModelFactory = factory;
  }