// uff.js
  /* jshint esversion: 6 */
  // Experimental

  // Namespace object that every class in this file is attached to.
  // `var x = x || ...` keeps an already-defined global instead of replacing it,
  // which is why these two declarations must stay `var`.
  var uff = uff || {};
  // Protobuf runtime: reuse an existing global, otherwise load the sibling module.
  var protobuf = protobuf || require('./protobuf');
  /**
   * Detects files that contain a uff.MetaGraph and opens them as uff.Model.
   */
  uff.ModelFactory = class {

      /**
       * Reports whether the file behind `context` looks like a uff.MetaGraph.
       *
       * Binary candidates ('.uff', '.pb') must expose tags 1..5 with the values
       * 0, 0, 2, 2, 2 (presumably protobuf field number -> wire type; confirm
       * against the host's `context.tags` implementation). Text candidates
       * ('.pbtxt', '*.uff.txt') must expose the 'version', 'descriptors' and
       * 'graphs' tags.
       *
       * @param context Host context providing `identifier` and `tags(format)`.
       * @returns {boolean} True when the file matches, false otherwise.
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          if (extension === 'uff' || extension === 'pb') {
              const tags = context.tags('pb');
              const expected = [ [ 1, 0 ], [ 2, 0 ], [ 3, 2 ], [ 4, 2 ], [ 5, 2 ] ];
              if (tags.size > 0 && expected.every(([ field, value ]) => tags.has(field) && tags.get(field) === value)) {
                  return true;
              }
          }
          if (extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt')) {
              const tags = context.tags('pbtxt');
              if ([ 'version', 'descriptors', 'graphs' ].every((name) => tags.has(name))) {
                  return true;
              }
          }
          return false;
      }

      /**
       * Loads the generated schema, decodes the file as a uff.MetaGraph (text or
       * binary depending on the file name) and wraps it in a uff.Model.
       *
       * @param context Host context providing `identifier`, `stream`, `require` and `request`.
       * @returns {Promise} Resolves to a uff.Model; rejects with uff.Error when decoding fails.
       */
      open(context) {
          return context.require('./uff-proto').then(() => {
              const identifier = context.identifier;
              const extension = identifier.split('.').pop().toLowerCase();
              const text = extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt');
              let meta_graph = null;
              try {
                  uff.proto = protobuf.get('uff').uff;
                  const stream = context.stream;
                  if (text) {
                      const reader = protobuf.TextReader.open(stream);
                      meta_graph = uff.proto.MetaGraph.decodeText(reader);
                  }
                  else {
                      const reader = protobuf.BinaryReader.open(stream);
                      meta_graph = uff.proto.MetaGraph.decode(reader);
                  }
              }
              catch (error) {
                  // Thrown values are not guaranteed to be Error objects, so fall back to
                  // their string form; the trailing period is dropped because the wrapper
                  // message appends its own.
                  const message = (error && error.message ? error.message : String(error)).replace(/\.$/, '');
                  const prefix = text ? 'File text format' : 'File format';
                  throw new uff.Error(prefix + ' is not uff.MetaGraph (' + message + ').');
              }
              return uff.Metadata.open(context).then((metadata) => {
                  return new uff.Model(metadata, meta_graph);
              });
          });
      }
  };
  /**
   * Top-level model built from a decoded meta graph.
   */
  uff.Model = class {

      /**
       * @param metadata Operator metadata forwarded to every uff.Graph.
       * @param meta_graph Decoded meta graph; reads `version`, `descriptors`,
       *     `referenced_data` and `graphs`. Field values of type 'ref' that name an
       *     entry in `referenced_data` are replaced in place by the referenced value.
       */
      constructor(metadata, meta_graph) {
          this._version = meta_graph.version;
          this._imports = meta_graph.descriptors.map((descriptor) => descriptor.id + ' v' + descriptor.version.toString());
          const references = new Map(meta_graph.referenced_data.map((item) => [ item.key, item.value ]));
          // Resolve references before the graphs are built so that nodes see the actual data.
          for (const graph of meta_graph.graphs) {
              for (const node of graph.nodes) {
                  for (const field of node.fields) {
                      const value = field.value;
                      if (value.type === 'ref' && references.has(value.ref)) {
                          field.value = references.get(value.ref);
                      }
                  }
              }
          }
          this._graphs = meta_graph.graphs.map((graph) => new uff.Graph(metadata, graph));
      }

      /** Format label: 'UFF', followed by ' v<version>' when the version is truthy. */
      get format() {
          return 'UFF' + (this._version ? ' v' + this._version.toString() : '');
      }

      /** Descriptor labels in the form '<id> v<version>'. */
      get imports() {
          return this._imports;
      }

      /** One uff.Graph per graph in the meta graph. */
      get graphs() {
          return this._graphs;
      }
  };
  /**
   * A single graph: its inputs, outputs and the nodes that remain after
   * single-use constants have been folded into arguments.
   */
  uff.Graph = class {

      /**
       * @param metadata Operator metadata forwarded to every uff.Node.
       * @param graph Decoded graph; reads `id` and `nodes` (each node exposing `id`,
       *     `operation`, `inputs` and `fields`). The graph object is not modified.
       */
      constructor(metadata, graph) {
          this._name = graph.id;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];
          const args = new Map();
          const inputCountMap = new Map();
          // First pass: one placeholder argument per referenced id, and a count of how
          // many times each id is consumed as an input.
          for (const node of graph.nodes) {
              for (const input of node.inputs) {
                  inputCountMap.set(input, (inputCountMap.get(input) || 0) + 1);
                  if (!args.has(input)) {
                      args.set(input, new uff.Argument(input));
                  }
              }
              if (!args.has(node.id)) {
                  args.set(node.id, new uff.Argument(node.id));
              }
          }
          // A Map (rather than a plain object) keeps field keys coming from the file
          // from colliding with inherited object properties.
          const fieldsOf = (node) => new Map(node.fields.map((field) => [ field.key, field.value ]));
          // Constants that were turned into initializers and must not appear as nodes.
          const folded = new Set();
          // Second pass, last-to-first so that if two nodes share an id the first one
          // determines the argument.
          for (const node of graph.nodes.slice().reverse()) {
              if (node.inputs.length !== 0) {
                  continue;
              }
              if (node.operation === 'Const' && inputCountMap.get(node.id) === 1) {
                  const fields = fieldsOf(node);
                  const dtype = fields.get('dtype');
                  const shape = fields.get('shape');
                  const values = fields.get('values');
                  if (dtype && shape && values) {
                      const tensor = new uff.Tensor(dtype.dtype, shape, values);
                      args.set(node.id, new uff.Argument(node.id, tensor.type, tensor));
                      folded.add(node);
                  }
              }
              if (node.operation === 'Input') {
                  const fields = fieldsOf(node);
                  const dtype = fields.get('dtype');
                  const shape = fields.get('shape');
                  const type = dtype && shape ? new uff.TensorType(dtype.dtype, shape) : null;
                  args.set(node.id, new uff.Argument(node.id, type, null));
              }
          }
          // Final pass: classify the remaining nodes in their original order.
          for (const node of graph.nodes) {
              if (folded.has(node)) {
                  continue;
              }
              if (node.operation === 'Input') {
                  this._inputs.push(new uff.Parameter(node.id, [ args.get(node.id) ]));
                  continue;
              }
              if (node.operation === 'MarkOutput' && node.inputs.length === 1) {
                  this._outputs.push(new uff.Parameter(node.id, [ args.get(node.inputs[0]) ]));
                  continue;
              }
              this._nodes.push(new uff.Node(metadata, node, args));
          }
      }

      /** Graph identifier. */
      get name() {
          return this._name;
      }

      /** One uff.Parameter per 'Input' node. */
      get inputs() {
          return this._inputs;
      }

      /** One uff.Parameter per single-input 'MarkOutput' node. */
      get outputs() {
          return this._outputs;
      }

      /** uff.Node instances for every node that is not an input, output marker or folded constant. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Named group of arguments attached to a graph or node input/output.
   */
  uff.Parameter = class {

      /**
       * @param {string} name Parameter name.
       * @param {Array} args Arguments grouped under this parameter.
       */
      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      /** Parameter name. */
      get name() {
          return this._name;
      }

      /** Always true. */
      get visible() {
          return true;
      }

      /** The arguments passed to the constructor. */
      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A value flowing through the graph, identified by name, with an optional
   * type and an optional constant initializer.
   */
  uff.Argument = class {

      /**
       * @param {string} name Identifier of the value.
       * @param type Optional type; stored as null when falsy.
       * @param initializer Optional constant value; stored as null when falsy.
       * @throws {uff.Error} When `name` is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              throw new uff.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type || null;
          this._initializer = initializer || null;
      }

      /** Identifier of the value. */
      get name() {
          return this._name;
      }

      /** Type of the value, or null. */
      get type() {
          return this._type;
      }

      /** Constant initializer, or null. */
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * One operation in a graph, with named inputs, a single output and attributes.
   */
  uff.Node = class {

      /**
       * @param metadata Provides `type(operation)` and `attribute(operation, key)`.
       * @param node Decoded node; reads `id`, `operation`, `inputs` and `fields`.
       * @param {Map} args Lookup from value id to uff.Argument.
       */
      constructor(metadata, node, args) {
          this._name = node.id;
          // Operators without metadata still get a minimal type carrying their name.
          this._type = metadata.type(node.operation) || { name: node.operation };
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          const inputs = node.inputs;
          if (inputs && inputs.length > 0) {
              let position = 0;
              // Inputs described by the schema are consumed in schema order.
              for (const schema of this._type.inputs || []) {
                  if (position >= inputs.length && schema.optional === true) {
                      continue;
                  }
                  // A list input takes everything that is left; any other input takes one.
                  const count = schema.list ? (inputs.length - position) : 1;
                  const group = inputs.slice(position, position + count).map((id) => args.get(id));
                  position += count;
                  this._inputs.push(new uff.Parameter(schema.name, group));
              }
              // Inputs beyond the schema are named by position; the very first is 'input'.
              for (let i = position; i < inputs.length; i++) {
                  const name = (i === 0) ? 'input' : i.toString();
                  this._inputs.push(new uff.Parameter(name, [ args.get(inputs[i]) ]));
              }
          }
          this._outputs.push(new uff.Parameter('output', [ args.get(node.id) ]));
          for (const field of node.fields) {
              this._attributes.push(new uff.Attribute(metadata.attribute(node.operation, field.key), field.key, field.value));
          }
      }

      /** Node identifier. */
      get name() {
          return this._name;
      }

      /** Operator metadata, or `{ name: operation }` when none is known. */
      get type() {
          return this._type;
      }

      /** Input parameters. */
      get inputs() {
          return this._inputs;
      }

      /** A single 'output' parameter holding the argument named after the node id. */
      get outputs() {
          return this._outputs;
      }

      /** One uff.Attribute per node field. */
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * Node attribute: a name plus a value and display type derived from a field value.
   */
  uff.Attribute = class {

      /**
       * @param metadata Attribute schema for this field; currently unused.
       * @param {string} name Field key.
       * @param value Field value; `value.type` names the member that holds the data.
       * @throws {uff.Error} When `value.type` is not one of the handled kinds.
       */
      constructor(metadata, name, value) {
          this._name = name;
          switch (value.type) {
              case 's':
                  this._value = value.s;
                  this._type = 'string';
                  break;
              case 's_list':
                  // NOTE(review): the numeric list cases below unwrap `.val`; confirm
                  // whether `s_list` and `b_list` are wrapped the same way.
                  this._value = value.s_list;
                  this._type = 'string[]';
                  break;
              case 'd':
                  this._value = value.d;
                  this._type = 'float64';
                  break;
              case 'd_list':
                  this._value = value.d_list.val;
                  this._type = 'float64[]';
                  break;
              case 'b':
                  this._value = value.b;
                  this._type = 'boolean';
                  break;
              case 'b_list':
                  this._value = value.b_list;
                  this._type = 'boolean[]';
                  break;
              case 'i':
                  this._value = value.i;
                  this._type = 'int64';
                  break;
              case 'i_list':
                  this._value = value.i_list.val;
                  this._type = 'int64[]';
                  break;
              case 'blob':
                  // No display type is assigned for blobs.
                  this._value = value.blob;
                  break;
              case 'ref':
                  this._value = value.ref;
                  this._type = 'ref';
                  break;
              case 'dtype':
                  this._value = new uff.TensorType(value.dtype, null).dataType;
                  this._type = 'uff.DataType';
                  break;
              case 'dtype_list': {
                  // Accept either a plain array or a `{ val: [...] }` wrapper like the
                  // other list kinds, so a wrapped list does not fail on `.map`.
                  const list = Array.isArray(value.dtype_list) ? value.dtype_list : value.dtype_list.val;
                  this._value = list.map((type) => new uff.TensorType(type, null).dataType);
                  this._type = 'uff.DataType[]';
                  break;
              }
              case 'dim_orders':
                  this._value = value.dim_orders;
                  break;
              case 'dim_orders_list':
                  this._value = value.dim_orders_list.val;
                  break;
              default:
                  throw new uff.Error("Unknown attribute '" + name + "' value '" + JSON.stringify(value) + "'.");
          }
      }

      /** Display type string; undefined for 'blob', 'dim_orders' and 'dim_orders_list'. */
      get type() {
          return this._type;
      }

      /** Field key. */
      get name() {
          return this._name;
      }

      /** Unwrapped field value. */
      get value() {
          return this._value;
      }

      /** Always true. */
      get visible() {
          return true;
      }
  };
  /**
   * Constant tensor backed by a raw byte blob, decoded on demand into nested arrays.
   */
  uff.Tensor = class {

      /**
       * @param dataType Data type value forwarded to uff.TensorType.
       * @param shape Shape field value forwarded to uff.TensorType.
       * @param values Field value holding the data; only `type === 'blob'` is supported.
       * @throws {uff.Error} When `values.type` is not 'blob'.
       */
      constructor(dataType, shape, values) {
          this._type = new uff.TensorType(dataType, shape);
          if (values.type !== 'blob') {
              throw new uff.Error("Unknown values format '" + JSON.stringify(values.type) + "'.");
          }
          this._data = values.blob;
      }

      /** Always 'Const'. */
      get kind() {
          return 'Const';
      }

      /** The uff.TensorType built from the constructor arguments. */
      get type() {
          return this._type;
      }

      /** Null when the data can be decoded, otherwise a message explaining why not. */
      get state() {
          return this._context().state;
      }

      /** Fully decoded data as nested arrays (a bare value for an empty shape), or null when not decodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** JSON rendering of the data, truncated with '...' past roughly 10000 elements; '' when not decodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /**
       * Builds a fresh decode context: either `state` describes why decoding is not
       * possible, or `dataType`, `shape` and a DataView over the blob are filled in.
       */
      _context() {
          const context = { state: null, index: 0, count: 0 };
          const data = this._data;
          if (data == null) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          // A blob that starts with the bytes "(..." and ends with "...)" is reported
          // as empty (presumably a placeholder written in place of elided data).
          const dot = 0x2e;
          const last = data.length - 1;
          if (data.length > 8 &&
              data[0] === 0x28 && data[1] === dot && data[2] === dot && data[3] === dot &&
              data[last] === 0x29 && data[last - 1] === dot && data[last - 2] === dot && data[last - 3] === dot) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          if (this._type.dataType === '?') {
              context.state = 'Tensor data type is unknown.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(data.buffer, data.byteOffset, data.byteLength);
          return context;
      }

      /**
       * Recursively decodes one dimension of the tensor.
       *
       * @param context Context from `_context()` with `limit` set by the caller.
       * @param {number} dimension Index of the dimension being decoded.
       * @returns Nested arrays of values; '...' marks where the limit cut decoding short.
       */
      _decode(context, dimension) {
          const scalar = context.shape.length === 0;
          // An empty shape is decoded as a single element and unwrapped at the end.
          const shape = scalar ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          const innermost = dimension === shape.length - 1;
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (!innermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              switch (context.dataType) {
                  case 'int8':
                      results.push(context.data.getInt8(context.index));
                      context.index += 1;
                      context.count++;
                      break;
                  case 'int16':
                      // Little-endian, like every other multi-byte type in this switch.
                      results.push(context.data.getInt16(context.index, true));
                      context.index += 2;
                      context.count++;
                      break;
                  case 'int32':
                      results.push(context.data.getInt32(context.index, true));
                      context.index += 4;
                      context.count++;
                      break;
                  case 'int64':
                      // NOTE(review): getInt64 is not a standard DataView method;
                      // presumably added by the host application. Confirm.
                      results.push(context.data.getInt64(context.index, true));
                      context.index += 8;
                      context.count++;
                      break;
                  case 'float16':
                      // NOTE(review): getFloat16 is not available on ES6 DataView;
                      // presumably added by the host application. Confirm.
                      results.push(context.data.getFloat16(context.index, true));
                      context.index += 2;
                      context.count++;
                      break;
                  case 'float32':
                      results.push(context.data.getFloat32(context.index, true));
                      context.index += 4;
                      context.count++;
                      break;
                  default:
                      break;
              }
          }
          if (scalar) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Element type plus optional shape of a tensor.
   */
  uff.TensorType = class {

      /**
       * @param dataType Value of the `uff.proto.DataType` enum (which must already be
       *     loaded), or 7, which is mapped to the unknown type '?'.
       * @param shape Shape field value wrapped in a uff.TensorShape, or a falsy value for no shape.
       * @throws {uff.Error} When `dataType` is not one of the handled values.
       */
      constructor(dataType, shape) {
          const DataType = uff.proto.DataType;
          switch (dataType) {
              case DataType.DT_INT8: this._dataType = 'int8'; break;
              case DataType.DT_INT16: this._dataType = 'int16'; break;
              case DataType.DT_INT32: this._dataType = 'int32'; break;
              case DataType.DT_INT64: this._dataType = 'int64'; break;
              case DataType.DT_FLOAT16: this._dataType = 'float16'; break;
              case DataType.DT_FLOAT32: this._dataType = 'float32'; break;
              case 7: this._dataType = '?'; break;
              default:
                  throw new uff.Error("Unknown data type '" + JSON.stringify(dataType) + "'.");
          }
          this._shape = shape ? new uff.TensorShape(shape) : null;
      }

      /** Element type name such as 'float32', or '?' when unknown. */
      get dataType() {
          return this._dataType;
      }

      /** The uff.TensorShape, or null when no shape was given. */
      get shape() {
          return this._shape;
      }

      /** Element type followed by the shape text; just the element type when there is no shape. */
      toString() {
          // The constructor allows a null shape, so it must not be dereferenced blindly.
          return this.dataType + (this._shape ? this._shape.toString() : '');
      }
  };
  /**
   * Tensor dimensions taken from an 'i_list' field value.
   */
  uff.TensorShape = class {

      /**
       * @param shape Field value whose `type` must be 'i_list'; dimensions are read from `shape.i_list.val`.
       * @throws {uff.Error} When `shape.type` is not 'i_list'.
       */
      constructor(shape) {
          if (shape.type !== 'i_list') {
              throw new uff.Error("Unknown shape format '" + JSON.stringify(shape.type) + "'.");
          }
          this._dimensions = shape.i_list.val;
      }

      /** The dimension list as stored in the field value. */
      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          if (!dimensions || dimensions.length === 0) {
              return '';
          }
          return '[' + dimensions.join(',') + ']';
      }
  };
  /**
   * Operator metadata loaded from 'uff-metadata.json' and indexed by operator name.
   */
  uff.Metadata = class {

      /**
       * Returns the shared metadata instance, loading it on first use.
       * When the request or parsing fails, an empty instance is cached instead so
       * that models can still be opened without operator metadata.
       *
       * @param context Host context providing `request(file, encoding, base)`.
       * @returns {Promise} Resolves to the shared uff.Metadata instance.
       */
      static open(context) {
          if (uff.Metadata._metadata) {
              return Promise.resolve(uff.Metadata._metadata);
          }
          return context.request('uff-metadata.json', 'utf-8', null).then((data) => {
              uff.Metadata._metadata = new uff.Metadata(data);
              return uff.Metadata._metadata;
          }).catch(() => {
              uff.Metadata._metadata = new uff.Metadata(null);
              return uff.Metadata._metadata;
          });
      }

      /**
       * @param {?string} data JSON text holding an array of operator schemas, each
       *     with a `name`; a falsy value yields empty metadata.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (data) {
              const metadata = JSON.parse(data);
              this._map = new Map(metadata.map((item) => [ item.name, item ]));
          }
      }

      /** Schema for the operator `name`, or undefined when unknown. */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Attribute schema `name` of operator `type`, or null when either is unknown.
       * The first lookup for an operator caches all of its attribute schemas.
       */
      attribute(type, name) {
          const cache = this._attributeCache;
          const key = type + ':' + name;
          if (!cache.has(key)) {
              const schema = this.type(type);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      cache.set(type + ':' + attribute.name, attribute);
                  }
              }
              // Remember misses too, so unknown attributes are not searched for again.
              if (!cache.has(key)) {
                  cache.set(key, null);
              }
          }
          return cache.get(key);
      }
  };
  /**
   * Error raised for any failure while reading a UFF model.
   */
  uff.Error = class extends Error {

      /**
       * @param {string} message Description of the failure.
       */
      constructor(message) {
          super(message);
          this.name = 'Error loading UFF model.';
      }
  };
  // Expose the factory to CommonJS consumers; skipped where no `module` object exists.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      module.exports.ModelFactory = uff.ModelFactory;
  }