uff.js 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var uff = uff || {};
  4. var protobuf = protobuf || require('./protobuf');
  /**
   * Recognizes UFF MetaGraph files (binary `.uff`/`.pb` or text `.pbtxt`/`.uff.txt`)
   * and opens them as uff.Model instances.
   */
  uff.ModelFactory = class {

      /**
       * Cheap format sniffing based on the file name and the host-provided tag maps.
       * Binary: the 'pb' tag map must report 0 for fields 1-2 and 2 for fields 3-5
       * (presumably protobuf wire types — confirm against context.tags).
       * Text: the 'pbtxt' tag map must contain `version`, `descriptors` and `graphs`.
       * @param {object} context - host context exposing `identifier` and `tags(kind)`.
       * @returns {boolean} true when the content looks like a UFF MetaGraph.
       */
      match(context) {
          const identifier = context.identifier;
          const extension = identifier.split('.').pop().toLowerCase();
          if (extension === 'uff' || extension === 'pb') {
              const tags = context.tags('pb');
              if (tags.size > 0 &&
                  tags.has(1) && tags.get(1) === 0 &&
                  tags.has(2) && tags.get(2) === 0 &&
                  tags.has(3) && tags.get(3) === 2 &&
                  tags.has(4) && tags.get(4) === 2 &&
                  tags.has(5) && tags.get(5) === 2) {
                  return true;
              }
          }
          if (extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt')) {
              const tags = context.tags('pbtxt');
              if (tags.has('version') && tags.has('descriptors') && tags.has('graphs')) {
                  return true;
              }
          }
          return false;
      }

      /**
       * Loads the UFF protobuf schema, decodes the MetaGraph (text or binary, chosen by
       * file name), loads operator metadata and builds the model.
       * @param {object} context - host context exposing `identifier`, `stream` and `require`.
       * @returns {Promise<uff.Model>}
       * @throws {uff.Error} (as a rejection) when the stream cannot be decoded.
       */
      open(context) {
          return context.require('./uff-proto').then(() => {
              const identifier = context.identifier;
              const extension = identifier.split('.').pop().toLowerCase();
              const text = extension === 'pbtxt' || identifier.toLowerCase().endsWith('.uff.txt');
              let meta_graph = null;
              try {
                  uff.proto = protobuf.get('uff').uff;
                  const stream = context.stream;
                  if (text) {
                      const reader = protobuf.TextReader.open(stream);
                      meta_graph = uff.proto.MetaGraph.decodeText(reader);
                  }
                  else {
                      const reader = protobuf.BinaryReader.open(stream);
                      meta_graph = uff.proto.MetaGraph.decode(reader);
                  }
              }
              catch (error) {
                  // Decoders may throw non-Error values: fall back to String(error). A trailing
                  // period is dropped so the wrapped message does not end in ".).".
                  const message = error && error.message ? error.message : String(error);
                  const format = text ? 'File text format' : 'File format';
                  throw new uff.Error(format + ' is not uff.MetaGraph (' + message.replace(/\.$/, '') + ').');
              }
              return uff.Metadata.open(context).then((metadata) => {
                  return new uff.Model(metadata, meta_graph);
              });
          });
      }
  };
  /**
   * Top-level UFF model: format string, imported descriptors and decoded graphs.
   */
  uff.Model = class {

      /**
       * @param {uff.Metadata} metadata - operator metadata, forwarded to every uff.Graph.
       * @param {object} meta_graph - decoded uff.MetaGraph message. Its node fields are
       *   mutated in place: 'ref' values with a known key are replaced by the referenced data.
       */
      constructor(metadata, meta_graph) {
          this._version = meta_graph.version;
          this._imports = [];
          for (const descriptor of meta_graph.descriptors) {
              this._imports.push(descriptor.id + ' v' + descriptor.version.toString());
          }
          const referenced = new Map();
          for (const item of meta_graph.referenced_data) {
              referenced.set(item.key, item.value);
          }
          // Inline referenced data before the graphs are built, so consumers see values, not refs.
          for (const graph of meta_graph.graphs) {
              for (const node of graph.nodes) {
                  for (const field of node.fields) {
                      const current = field.value;
                      if (current.type !== 'ref' || !referenced.has(current.ref)) {
                          continue;
                      }
                      field.value = referenced.get(current.ref);
                  }
              }
          }
          this._graphs = meta_graph.graphs.map((graph) => new uff.Graph(metadata, graph));
      }

      /** 'UFF', suffixed with ' v<version>' when a truthy version is present. */
      get format() {
          const suffix = this._version ? ' v' + this._version.toString() : '';
          return 'UFF' + suffix;
      }

      /** Descriptor strings of the form '<id> v<version>'. */
      get imports() {
          return this._imports;
      }

      get graphs() {
          return this._graphs;
      }
  };
  /**
   * A UFF graph split into graph inputs ('Input' nodes), graph outputs ('MarkOutput'
   * nodes with exactly one input) and ordinary nodes.
   */
  uff.Graph = class {

      /**
       * @param {uff.Metadata} metadata - operator metadata forwarded to uff.Node.
       * @param {object} graph - decoded uff.Graph message. NOTE: graph.nodes is mutated —
       *   a 'Const' node with no inputs, referenced exactly once and carrying dtype/shape/values
       *   fields is removed and becomes the initializer of the argument with its id.
       */
      constructor(metadata, graph) {
          this._name = graph.id;
          this._inputs = [];
          this._outputs = [];
          this._nodes = [];

          // Collect node fields into a plain { key: value } object.
          const fieldsOf = (node) => {
              const fields = {};
              for (const field of node.fields) {
                  fields[field.key] = field.value;
              }
              return fields;
          };

          // One argument per referenced or produced name; count how often each name is consumed.
          const args = new Map();
          const uses = new Map();
          for (const node of graph.nodes) {
              for (const input of node.inputs) {
                  uses.set(input, (uses.get(input) || 0) + 1);
                  args.set(input, new uff.Argument(input));
              }
              if (!args.has(node.id)) {
                  args.set(node.id, new uff.Argument(node.id));
              }
          }

          // Walk backwards so splicing does not disturb indices still to be visited.
          for (let index = graph.nodes.length - 1; index >= 0; index--) {
              const node = graph.nodes[index];
              if (node.inputs.length !== 0) {
                  continue;
              }
              switch (node.operation) {
                  case 'Const': {
                      if (uses.get(node.id) !== 1) {
                          break;
                      }
                      const fields = fieldsOf(node);
                      if (fields.dtype && fields.shape && fields.values) {
                          const tensor = new uff.Tensor(fields.dtype.dtype, fields.shape, fields.values);
                          args.set(node.id, new uff.Argument(node.id, tensor.type, tensor));
                          graph.nodes.splice(index, 1);
                      }
                      break;
                  }
                  case 'Input': {
                      const fields = fieldsOf(node);
                      const type = fields.dtype && fields.shape ? new uff.TensorType(fields.dtype.dtype, fields.shape) : null;
                      args.set(node.id, new uff.Argument(node.id, type, null));
                      break;
                  }
                  default:
                      break;
              }
          }

          for (const node of graph.nodes) {
              if (node.operation === 'Input') {
                  this._inputs.push(new uff.Parameter(node.id, [ args.get(node.id) ]));
              }
              else if (node.operation === 'MarkOutput' && node.inputs.length === 1) {
                  this._outputs.push(new uff.Parameter(node.id, [ args.get(node.inputs[0]) ]));
              }
              else {
                  this._nodes.push(new uff.Node(metadata, node, args));
              }
          }
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named slot (graph input/output or node input/output) holding a list of arguments.
   * Always reported as visible.
   */
  uff.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {Array} args - arguments bound to the slot (stored as given, not copied).
       */
      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      get name() { return this._name; }

      get visible() { return true; }

      get arguments() { return this._arguments; }
  };
  /**
   * A named value flowing through the graph, with an optional type and initializer.
   */
  uff.Argument = class {

      /**
       * @param {string} name - argument identifier; must be a string.
       * @param {uff.TensorType} [type] - stored as null when falsy.
       * @param {uff.Tensor} [initializer] - stored as null when falsy.
       * @throws {uff.Error} when name is not a string.
       */
      constructor(name, type, initializer) {
          if (typeof name === 'string') {
              this._name = name;
              this._type = type || null;
              this._initializer = initializer || null;
              return;
          }
          throw new uff.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
      }

      get name() { return this._name; }

      get type() { return this._type; }

      get initializer() { return this._initializer; }
  };
  /**
   * A graph operation with its inputs (named by operator metadata when available),
   * a single 'output' parameter and its fields as attributes.
   */
  uff.Node = class {

      /**
       * @param {uff.Metadata} metadata - provides type(op) and attribute(op, name).
       * @param {object} node - decoded uff.Node message (id, operation, inputs, fields).
       * @param {Map<string, uff.Argument>} args - argument lookup by name.
       */
      constructor(metadata, node, args) {
          this._name = node.id;
          this._operation = node.operation;
          this._metadata = metadata.type(node.operation);
          this._attributes = [];
          this._inputs = [];
          this._outputs = [];
          const schema = this._metadata;
          const inputs = node.inputs;
          if (inputs && inputs.length > 0) {
              let consumed = 0;
              const declared = schema && schema.inputs ? schema.inputs : [];
              for (const input of declared) {
                  // Optional inputs are skipped once all actual inputs are used up;
                  // required ones are always emitted, possibly with no arguments.
                  if (consumed >= inputs.length && input.optional === true) {
                      continue;
                  }
                  // A 'list' input absorbs every remaining actual input.
                  const count = input.list ? inputs.length - consumed : 1;
                  const bound = inputs.slice(consumed, consumed + count).map((id) => args.get(id));
                  consumed += count;
                  this._inputs.push(new uff.Parameter(input.name, bound));
              }
              // Inputs not covered by metadata are named by position ('input' for position 0).
              inputs.slice(consumed).forEach((id, offset) => {
                  const position = consumed + offset;
                  const label = position === 0 ? 'input' : position.toString();
                  this._inputs.push(new uff.Parameter(label, [ args.get(id) ]));
              });
          }
          this._outputs.push(new uff.Parameter('output', [ args.get(node.id) ]));
          for (const field of node.fields) {
              const attributeSchema = metadata.attribute(this._operation, field.key);
              this._attributes.push(new uff.Attribute(attributeSchema, field.key, field.value));
          }
      }

      get name() { return this._name; }

      get type() { return this._operation; }

      get metadata() { return this._metadata; }

      get inputs() { return this._inputs; }

      get outputs() { return this._outputs; }

      get attributes() { return this._attributes; }
  };
  /**
   * A node field exposed as an attribute, decoded from a uff.Data-style value whose
   * `type` names the populated member.
   */
  uff.Attribute = class {

      /**
       * @param {object|null} metadata - attribute schema; not used by this constructor.
       * @param {string} name - field key.
       * @param {object} value - decoded value; `value.type` selects which member is read.
       * @throws {uff.Error} when `value.type` is not recognized.
       */
      constructor(metadata, name, value) {
          this._name = name;
          // List members are read through `.val` for d_list/i_list/dim_orders_list below;
          // unwrap s_list/b_list/dtype_list the same way when they arrive as a wrapper
          // message, while still accepting a plain array.
          const items = (list) => (list && list.val !== undefined ? list.val : list);
          switch (value.type) {
              case 's': this._value = value.s; this._type = 'string'; break;
              case 's_list': this._value = items(value.s_list); this._type = 'string[]'; break;
              case 'd': this._value = value.d; this._type = 'float64'; break;
              case 'd_list': this._value = value.d_list.val; this._type = 'float64[]'; break;
              case 'b': this._value = value.b; this._type = 'boolean'; break;
              case 'b_list': this._value = items(value.b_list); this._type = 'boolean[]'; break;
              case 'i': this._value = value.i; this._type = 'int64'; break;
              case 'i_list': this._value = value.i_list.val; this._type = 'int64[]'; break;
              case 'blob': this._value = value.blob; break;
              case 'ref': this._value = value.ref; this._type = 'ref'; break;
              case 'dtype': this._value = new uff.TensorType(value.dtype, null).dataType; this._type = 'uff.DataType'; break;
              case 'dtype_list': this._value = items(value.dtype_list).map((type) => new uff.TensorType(type, null).dataType); this._type = 'uff.DataType[]'; break;
              case 'dim_orders': this._value = value.dim_orders; break;
              case 'dim_orders_list': this._value = value.dim_orders_list.val; break;
              default: throw new uff.Error("Unknown attribute '" + name + "' value '" + JSON.stringify(value) + "'.");
          }
      }

      /** Type label such as 'int64[]'; left undefined for blob / dim_orders values. */
      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      get visible() {
          return true;
      }
  };
  /**
   * A constant tensor backed by a raw 'blob' value. Elements are decoded on demand,
   * little-endian for every multi-byte type.
   */
  uff.Tensor = class {

      /**
       * @param {*} dataType - data type value, forwarded to uff.TensorType.
       * @param {object} shape - shape field value, forwarded to uff.TensorType.
       * @param {object} values - values field; only the 'blob' variant is supported.
       * @throws {uff.Error} when values is not a blob.
       */
      constructor(dataType, shape, values) {
          this._type = new uff.TensorType(dataType, shape);
          switch (values.type) {
              case 'blob': this._data = values.blob; break;
              default: throw new uff.Error("Unknown values format '" + JSON.stringify(values.type) + "'.");
          }
      }

      get kind() {
          return 'Const';
      }

      get type() {
          return this._type;
      }

      /** @returns {string|null} reason the data cannot be decoded, or null when decodable. */
      get state() {
          return this._context().state;
      }

      /** All elements as nested arrays following the shape (a bare value for rank 0), or null. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** Indented JSON of the elements, truncated with '...' past 10000 elements; '' if undecodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      /** Builds the decode cursor, or a context whose `state` explains why decoding is impossible. */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (this._data == null) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          // A blob starting with '(...' and ending with '...)' is treated as empty —
          // presumably a placeholder written where weights were stripped (confirm).
          if (this._data.length > 8 &&
              this._data[0] === 0x28 && this._data[1] === 0x2e && this._data[2] === 0x2e && this._data[3] === 0x2e &&
              this._data[this._data.length - 1] === 0x29 && this._data[this._data.length - 2] === 0x2e && this._data[this._data.length - 3] === 0x2e && this._data[this._data.length - 4] === 0x2e) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          if (this._type.dataType === '?') {
              context.state = 'Tensor data type is unknown.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.shape = this._type.shape.dimensions;
          context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      /** Recursively decodes `dimension` and below, advancing context.index / context.count. */
      _decode(context, dimension) {
          // Rank-0 tensors are decoded as a single element and unwrapped at the end.
          const shape = (context.shape.length === 0) ? [ 1 ] : context.shape;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'int8':
                          results.push(context.data.getInt8(context.index));
                          context.index += 1;
                          context.count++;
                          break;
                      case 'int16':
                          // Little-endian, like every other multi-byte type decoded here.
                          results.push(context.data.getInt16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'int32':
                          results.push(context.data.getInt32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      case 'int64':
                          // NOTE(review): getInt64 is not a standard DataView method (the standard
                          // one is getBigInt64); presumably polyfilled by the host — confirm.
                          results.push(context.data.getInt64(context.index, true));
                          context.index += 8;
                          context.count++;
                          break;
                      case 'float16':
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          context.count++;
                          break;
                      case 'float32':
                          results.push(context.data.getFloat32(context.index, true));
                          context.index += 4;
                          context.count++;
                          break;
                      default:
                          break;
                  }
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.shape.length === 0) {
              return results[0];
          }
          return results;
      }
  };
  /**
   * Maps a uff.proto.DataType value to a data type name, with an optional shape.
   */
  uff.TensorType = class {

      /**
       * @param {number} dataType - a uff.proto.DataType value (or 7, mapped to '?').
       * @param {object|null} shape - shape field value for uff.TensorShape; null for no shape.
       * @throws {uff.Error} for an unrecognized data type.
       */
      constructor(dataType, shape) {
          switch (dataType) {
              case uff.proto.DataType.DT_INT8: this._dataType = 'int8'; break;
              case uff.proto.DataType.DT_INT16: this._dataType = 'int16'; break;
              case uff.proto.DataType.DT_INT32: this._dataType = 'int32'; break;
              case uff.proto.DataType.DT_INT64: this._dataType = 'int64'; break;
              case uff.proto.DataType.DT_FLOAT16: this._dataType = 'float16'; break;
              case uff.proto.DataType.DT_FLOAT32: this._dataType = 'float32'; break;
              // NOTE(review): 7 is not one of the named enum members above; it is surfaced
              // as the unknown type '?' — origin of this value is not visible here, confirm.
              case 7: this._dataType = '?'; break;
              default:
                  throw new uff.Error("Unknown data type '" + JSON.stringify(dataType) + "'.");
          }
          this._shape = shape ? new uff.TensorShape(shape) : null;
      }

      get dataType() {
          return this._dataType;
      }

      /** @returns {uff.TensorShape|null} */
      get shape() {
          return this._shape;
      }

      toString() {
          // Types built without a shape (e.g. for 'dtype' attributes) print the data type alone.
          return this.dataType + (this._shape ? this._shape.toString() : '');
      }
  };
  /**
   * Tensor dimensions read from an 'i_list' shape value.
   */
  uff.TensorShape = class {

      /**
       * @param {object} shape - value with type 'i_list'; dimensions come from shape.i_list.val.
       * @throws {uff.Error} for any other shape encoding.
       */
      constructor(shape) {
          if (shape.type === 'i_list') {
              this._dimensions = shape.i_list.val;
              return;
          }
          throw new uff.Error("Unknown shape format '" + JSON.stringify(shape.type) + "'.");
      }

      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dims = this._dimensions;
          return dims && dims.length !== 0 ? '[' + dims.join(',') + ']' : '';
      }
  };
  /**
   * Operator metadata loaded from 'uff-metadata.json', with per-attribute lookup caching.
   */
  uff.Metadata = class {

      /**
       * Returns the process-wide metadata instance, loading it on first use. Any request
       * or parse failure yields (and caches) an empty metadata instance instead of rejecting.
       * @param {object} context - host context exposing request(file, encoding, ...).
       * @returns {Promise<uff.Metadata>}
       */
      static open(context) {
          const cached = uff.Metadata._metadata;
          if (cached) {
              return Promise.resolve(cached);
          }
          const remember = (instance) => {
              uff.Metadata._metadata = instance;
              return instance;
          };
          return context.request('uff-metadata.json', 'utf-8', null)
              .then((data) => remember(new uff.Metadata(data)))
              .catch(() => remember(new uff.Metadata(null)));
      }

      /**
       * @param {string|null} data - JSON text of an array of operator schemas keyed by `name`.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (!data) {
              return;
          }
          for (const item of JSON.parse(data)) {
              this._map.set(item.name, item);
          }
      }

      /** @returns {object|undefined} the schema for operator `name`. */
      type(name) {
          return this._map.get(name);
      }

      /**
       * Looks up attribute `name` of operator `type`. The first lookup for an operator caches
       * all of its declared attributes; unknown attributes are cached as null.
       * @returns {object|null}
       */
      attribute(type, name) {
          const key = type + ':' + name;
          const cache = this._attributeCache;
          if (cache.has(key)) {
              return cache.get(key);
          }
          const schema = this.type(type);
          const declared = schema && schema.attributes && schema.attributes.length > 0 ? schema.attributes : [];
          for (const attribute of declared) {
              cache.set(type + ':' + attribute.name, attribute);
          }
          if (!cache.has(key)) {
              cache.set(key, null);
          }
          return cache.get(key);
      }
  };
  /**
   * Error raised while loading UFF content; `name` is set to a fixed
   * 'Error loading UFF model.' label, and `message` carries the specific reason.
   */
  uff.Error = class extends Error {

      /** @param {string} message - description of what failed. */
      constructor(message) {
          super(message);
          this.name = 'Error loading UFF model.';
      }
  };
  // Only the factory is exported, and only when a CommonJS `module` object is present.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: uff.ModelFactory });
  }