om.js 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. // Experimental
  2. var om = om || {};
  3. var protobuf = protobuf || require('./protobuf');
  4. var base = base || require('./base');
  /**
   * Entry point used by the host to recognize and load DaVinci OM files.
   */
  om.ModelFactory = class {

    /**
     * Checks whether the context's stream looks like an OM container.
     * @param {object} context - host loading context (exposes `stream`).
     * @returns {om.File|null} a lazily parsed file on a signature match, otherwise null.
     */
    match(context) {
      return om.File.open(context);
    }

    /**
     * Decodes the ModelDef partition and builds the displayable model.
     * @param {object} context - host loading context (`require`, `request`).
     * @param {om.File} match - value previously returned by `match`.
     * @returns {Promise<om.Model>}
     * @throws {om.Error} when there is no model partition or it cannot be decoded.
     */
    open(context, match) {
      const file = match;
      if (!file || !file.model) {
        // Must be constructed with `new`: calling a class without it raises a TypeError instead.
        throw new om.Error('File does not contain a model definition.');
      }
      return context.require('./om-proto').then(() => {
        let model = null;
        try {
          om.proto = protobuf.get('om').ge.proto;
          const reader = protobuf.BinaryReader.open(file.model);
          model = om.proto.ModelDef.decode(reader);
        }
        catch (error) {
          // String() also copes with null/undefined throw values, unlike error.toString().
          const message = error && error.message ? error.message : String(error);
          throw new om.Error('File format is not ge.proto.ModelDef (' + message.replace(/\.$/, '') + ').');
        }
        return om.Metadata.open(context).then((metadata) => {
          return new om.Model(metadata, model, file.weights);
        });
      });
    }
  };
  /**
   * Top-level model: one om.Graph per graph entry of the decoded ModelDef.
   */
  om.Model = class {

    /**
     * @param {om.Metadata} metadata - operator schema lookup.
     * @param {object} model - decoded ModelDef; only its `graph` list is read.
     * @param {*} weights - weight data shared by every graph through a common context object.
     */
    constructor(metadata, model, weights) {
      const shared = { metadata: metadata, weights: weights };
      this._graphs = Array.from(model.graph, (graphDef) => new om.Graph(shared, graphDef));
    }

    /** @returns {string} display name of the format. */
    get format() {
      return 'DaVinci OM';
    }

    /** @returns {om.Graph[]} */
    get graphs() {
      return this._graphs;
    }
  };
  /**
   * One graph of the model. Ops of type 'Const' are not turned into nodes here.
   */
  om.Graph = class {

    /**
     * @param {object} context - shared `{ metadata, weights }` object.
     * @param {object} graph - decoded graph definition; reads `name` and `op`.
     */
    constructor(context, graph) {
      const shownOps = Array.from(graph.op).filter((op) => op.type !== 'Const');
      this._name = graph.name;
      this._nodes = shownOps.map((op) => new om.Node(context, op, graph));
      // Graph-level inputs and outputs are never populated by this loader.
      this._inputs = [];
      this._outputs = [];
    }

    /** @returns {string} */
    get name() {
      return this._name;
    }

    /** @returns {om.Node[]} */
    get nodes() {
      return this._nodes;
    }

    /** @returns {Array} always empty. */
    get inputs() {
      return this._inputs;
    }

    /** @returns {Array} always empty. */
    get outputs() {
      return this._outputs;
    }
  };
  /**
   * A displayable operator built from one decoded op definition.
   */
  om.Node = class {

    /**
     * @param {object} context - shared `{ metadata, weights }` object.
     * @param {object} op - decoded op; reads `name`, `type`, `input`, `input_desc`, `output_desc` and `attr`.
     * @param {object} graph - owning graph; its `op` list is searched to find each input's producer.
     */
    constructor(context, op, graph) {
      this._name = op.name;
      this._type = context.metadata.type(op.type) || { name: op.type };
      this._inputs = [];
      this._outputs = [];
      this._attributes = [];
      this._chain = [];
      this._controlDependencies = [];
      this._device = null;
      if (op.input) {
        for (let i = 0; i < op.input.length; i++) {
          const input = op.input[i];
          if (input === '') {
            continue;
          }
          // Inputs reference '<producer>:<output index>'; an index of '-1' marks a control edge.
          // A reference without ':' is treated as output 0 of the named producer.
          const pos = input.lastIndexOf(':');
          let name = input;
          let src_index = '0';
          if (pos !== -1) {
            name = pos === 0 ? 'internal_unnamed' : input.slice(0, pos);
            src_index = input.slice(pos + 1);
          }
          if (src_index === '-1') {
            this._controlDependencies.push(new om.Argument(name));
            continue;
          }
          const parameterName = this._type.inputs && i < this._type.inputs.length ? this._type.inputs[i].name : 'input' + (i === 0 ? '' : i.toString());
          const inputNode = graph.op.find((node) => node.name === name);
          // input_desc may be shorter than input (or absent); the argument is then left untyped.
          const desc = op.input_desc ? op.input_desc[i] : undefined;
          const format = desc ? desc.layout : undefined;
          let argument = null;
          if (inputNode && inputNode.type === 'Const' && inputNode.attr && inputNode.attr.value) {
            // Const producers are not graph nodes; their tensor becomes this input's initializer.
            const value = inputNode.attr.value.t;
            let shape = null;
            if (value.desc.shape != null) {
              shape = value.desc.shape.dim;
            }
            if (value.desc.attr.origin_shape) {
              shape = value.desc.attr.origin_shape.list.i;
            }
            let data = null;
            if (value.data.length !== 0) {
              data = value.data;
            }
            else if (context.weights != null) {
              // No inline bytes: slice them out of the weights blob at the recorded offset.
              const offset = value.desc.attr.merged_offset ? value.desc.attr.merged_offset.i : value.desc.data_offset;
              data = context.weights.slice(offset, offset + value.desc.weight_size);
            }
            const tensorType = new om.TensorType(om.Utility.dtype(value.desc.dtype), shape, format, value.desc.layout);
            argument = new om.Argument(name, null, new om.Tensor('Constant', tensorType, data));
          }
          else {
            const dataType = desc ? om.Utility.dtype(desc.dtype) : 'undefined';
            const shape = desc && desc.shape ? desc.shape.dim : undefined;
            const tensorType = new om.TensorType(dataType, shape, format, null);
            const identifier = src_index === '0' ? name : name + ':' + src_index;
            argument = new om.Argument(identifier, tensorType, null);
          }
          this._inputs.push(new om.Parameter(parameterName, true, [ argument ]));
        }
      }
      if (op.output_desc) {
        // Data-style ops may leave the output shape empty; the first input_desc shape is used instead.
        const isDataOp = op.type === 'Data' || op.type === 'ImageData' || op.type === 'DynamicImageData';
        for (let i = 0; i < op.output_desc.length; i++) {
          const desc = op.output_desc[i];
          let shape = desc.shape ? desc.shape.dim : undefined;
          if (!desc.shape && isDataOp && op.input_desc && op.input_desc.length > 0 && op.input_desc[0].shape) {
            shape = op.input_desc[0].shape.dim;
          }
          const tensorType = new om.TensorType(om.Utility.dtype(desc.dtype), shape, desc.layout);
          const identifier = i === 0 ? this._name : this._name + ':' + i;
          const outputName = this._type.outputs && i < this._type.outputs.length ? this._type.outputs[i].name : 'output' + (i === 0 ? '' : i.toString());
          this._outputs.push(new om.Parameter(outputName, true, [ new om.Argument(identifier, tensorType, null) ]));
        }
      }
      if (op.attr) {
        for (const [ name, value ] of Object.entries(op.attr)) {
          if (name === 'device') {
            this._device = value;
            continue;
          }
          if (name === 'original_op_names') {
            continue;
          }
          // A true relu_flag (presumably a fused activation) is shown as a chained ReLU node.
          if (name === 'relu_flag' && value.b) {
            this._chain.push(new om.Node(context, { type: 'ReLU' }, graph));
            continue;
          }
          this._attributes.push(new om.Attribute(context, name, value));
        }
      }
    }

    /** @returns {*} raw value of the op's 'device' attribute, or null. */
    get device() {
      return this._device;
    }

    /** @returns {string} */
    get name() {
      return this._name || '';
    }

    /** @returns {object} metadata schema, or `{ name }` when the type is unknown. */
    get type() {
      return this._type;
    }

    /** @returns {om.Parameter[]} */
    get inputs() {
      return this._inputs;
    }

    /** @returns {om.Parameter[]} */
    get outputs() {
      return this._outputs;
    }

    /** @returns {om.Attribute[]} */
    get attributes() {
      return this._attributes;
    }

    /** @returns {om.Node[]} chained nodes (fused ReLU). */
    get chain() {
      return this._chain;
    }

    /** @returns {om.Argument[]} */
    get controlDependencies() {
      return this._controlDependencies;
    }
  };
  /**
   * One entry of an op's `attr` map, converted into a `{ name, type, value }` shape for display.
   */
  om.Attribute = class {

    /**
     * @param {object} context - shared `{ metadata, weights }`; forwarded to nested graphs.
     * @param {string} name - attribute key.
     * @param {object} value - decoded attribute; `value.value` names the populated field
     *   ('i', 'f', 'b', 'bt', 'dt', 's', 'g', 'func', 'list').
     */
    constructor(context, name, value) {
      this._name = name;
      // Fields not handled below (including 'func') keep the raw decoded object and no type.
      this._value = value;
      switch (value.value) {
        case 'i': {
          this._value = value.i;
          this._type = 'int64';
          break;
        }
        case 'f': {
          this._value = value.f;
          this._type = 'float32';
          break;
        }
        case 'b': {
          this._value = value.b;
          this._type = 'boolean';
          break;
        }
        case 'bt': {
          // Raw bytes are shown as a float32 vector (4 bytes per element); an empty payload shows nothing.
          this._value = null;
          if (value.bt.length !== 0) {
            this._type = 'tensor';
            this._value = new om.Tensor('Constant', new om.TensorType('float32', [ value.bt.length / 4 ], null), value.bt);
          }
          break;
        }
        case 'dt': {
          this._type = 'DataType';
          // Read through toNumber(), so dt presumably arrives as a 64-bit integer object.
          this._value = om.Utility.dtype(value.dt.toNumber());
          break;
        }
        case 's': {
          this._type = 'string';
          this._value = typeof value.s === 'string' ? value.s : om.Utility.decodeText(value.s);
          break;
        }
        case 'g': {
          this._type = 'graph';
          this._value = new om.Graph(context, value.g);
          break;
        }
        case 'func': {
          break;
        }
        case 'list': {
          const list = value.list;
          this._value = [];
          if (list.s && list.s.length > 0) {
            // Decode each entry as UTF-8, like the scalar 's' case, then join for display.
            this._value = list.s.map((item) => (typeof item === 'string' ? item : om.Utility.decodeText(item))).join(', ');
            this._type = 'string[]';
          }
          else if (list.b && list.b.length > 0) {
            this._value = list.b;
            this._type = 'boolean[]';
          }
          else if (list.i && list.i.length > 0) {
            this._value = list.i;
            this._type = 'int64[]';
          }
          else if (list.f && list.f.length > 0) {
            this._value = list.f;
            this._type = 'float32[]';
          }
          else if (list.type && list.type.length > 0) {
            this._type = 'type[]';
            // Unknown data type codes are shown as '?' instead of aborting the whole model load.
            this._value = list.type.map((type) => {
              try {
                return om.Utility.dtype(type) || '?';
              }
              catch (error) {
                return '?';
              }
            });
          }
          else if (list.shape && list.shape.length > 0) {
            this._type = 'shape[]';
            this._value = list.shape.map((shape) => new om.TensorShape(shape));
          }
          break;
        }
        case undefined: {
          this._value = null;
          break;
        }
        default: {
          break;
        }
      }
    }

    /** @returns {string} */
    get name() {
      return this._name;
    }

    /** @returns {string|undefined} display type, undefined for unhandled kinds. */
    get type() {
      return this._type;
    }

    /** @returns {*} */
    get value() {
      return this._value;
    }

    /** @returns {boolean} always true. */
    get visible() {
      return true;
    }
  };
  /**
   * A named input or output slot of a node, holding its bound arguments.
   */
  om.Parameter = class {

    /**
     * @param {string} name - slot name.
     * @param {boolean} visible - whether the slot is shown.
     * @param {om.Argument[]} args - arguments bound to the slot.
     */
    constructor(name, visible, args) {
      Object.assign(this, { _name: name, _visible: visible, _arguments: args });
    }

    /** @returns {string} */
    get name() {
      return this._name;
    }

    /** @returns {boolean} */
    get visible() {
      return this._visible;
    }

    /** @returns {om.Argument[]} */
    get arguments() {
      return this._arguments;
    }
  };
  /**
   * A value flowing along an edge: an identifier plus an optional type and optional constant initializer.
   */
  om.Argument = class {

    /**
     * @param {string} name - edge identifier.
     * @param {om.TensorType} [type] - static type; ignored when an initializer is present.
     * @param {om.Tensor} [initializer] - constant tensor backing this argument.
     * @throws {om.Error} when `name` is not a string.
     */
    constructor(name, type, initializer) {
      if (typeof name !== 'string') {
        throw new om.Error(`Invalid argument identifier '${JSON.stringify(name)}'.`);
      }
      this._name = name;
      this._type = type ? type : null;
      this._initializer = initializer ? initializer : null;
    }

    /** @returns {string} */
    get name() {
      return this._name;
    }

    /** @returns {om.TensorType|null} the initializer's type when one exists, else the declared type. */
    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    /** @returns {om.Tensor|null} */
    get initializer() {
      return this._initializer;
    }
  };
  /**
   * A constant tensor used as an argument initializer. Its data is kept but not decoded for display.
   */
  om.Tensor = class {

    /**
     * @param {string} kind - display kind (e.g. 'Constant').
     * @param {om.TensorType} type - must expose `shape.dimensions`.
     * @param {*} value - raw tensor data (may be null).
     */
    constructor(kind, type, value) {
      const dimensions = type.shape.dimensions;
      this._type = type;
      this._name = '';
      this._kind = kind;
      this._data = value;
      this._shape = dimensions;
    }

    /** @returns {string} always empty. */
    get name() {
      return this._name;
    }

    /** @returns {om.TensorType} */
    get type() {
      return this._type;
    }

    /** @returns {string} */
    get kind() {
      return this._kind;
    }

    set kind(value) {
      this._kind = value;
    }

    /** @returns {string} reason the data is not shown. */
    get state() {
      return 'Tensor data not implemented.';
    }
  };
  /**
   * Data type, shape and layout label(s) of a tensor.
   */
  om.TensorType = class {

    /**
     * @param {string} dataType - display name of the element type.
     * @param {Array|undefined|null} shape - dimension list, wrapped in an om.TensorShape.
     * @param {string} [format] - layout reported for this edge.
     * @param {string} [denotation] - secondary layout; only shown when it differs from `format`.
     */
    constructor(dataType, shape, format, denotation) {
      this._dataType = dataType;
      this._shape = new om.TensorShape(shape);
      const secondary = denotation && denotation !== format ? denotation : '';
      this._denotation = [ format, secondary ].filter((label) => label).join(' ');
    }

    /** @returns {string} */
    get dataType() {
      return this._dataType;
    }

    set shape(dims) {
      this._shape = dims;
    }

    /** @returns {om.TensorShape} */
    get shape() {
      return this._shape;
    }

    /** @returns {string} space-separated layout labels ('' when none). */
    get denotation() {
      return this._denotation;
    }

    /** @returns {string} e.g. 'float32[1,3]'. */
    toString() {
      return this._dataType + this._shape.toString();
    }
  };
  /**
   * A list of tensor dimensions.
   */
  om.TensorShape = class {

    /**
     * @param {Array|undefined|null} shape - dimension values (numbers or integer-like objects with toString()).
     */
    constructor(shape) {
      this._shape = shape;
    }

    /** @returns {Array|undefined|null} the dimension list as given. */
    get dimensions() {
      return this._shape;
    }

    /**
     * @returns {string} '[d0,d1,...]' with '?' for null/undefined dims, or '' when there are no dims.
     */
    toString() {
      if (!Array.isArray(this._shape) || this._shape.length === 0) {
        return '';
      }
      // Only missing dims become '?'; a real 0-sized dimension is printed as 0.
      return '[' + this._shape.map((dim) => (dim === null || dim === undefined ? '?' : dim.toString())).join(',') + ']';
    }
  };
  /**
   * Lazily parsed OM container: a header, a partition table, then the partition payloads.
   */
  om.File = class {

    /**
     * @param {object} context - loading context exposing `stream` (`length`, `peek`).
     * @returns {om.File|null} an unparsed file when the stream holds at least 256 bytes and starts with 'IMOD', else null.
     */
    static open(context) {
      const stream = context.stream;
      const signature = [ 0x49, 0x4D, 0x4F, 0x44 ]; // IMOD
      if (stream.length >= 256 && stream.peek(4).every((value, index) => value === signature[index])) {
        const reader = new base.BinaryReader(stream);
        return new om.File(reader);
      }
      return null;
    }

    constructor(reader) {
      this._reader = reader;
    }

    /** @returns {*} bytes of the MODEL_DEF partition, if present. */
    get model() {
      this._read();
      return this._model;
    }

    /** @returns {*} bytes of the MODEL_WEIGHT partition, if present. */
    get weights() {
      this._read();
      return this._weights;
    }

    /**
     * Parses the header and partitions on first access. The reader is dropped afterwards,
     * so later calls do nothing.
     * @throws {om.Error} on an unknown partition type.
     */
    _read() {
      if (!this._reader) {
        return;
      }
      const reader = this._reader;
      delete this._reader;
      const decoder = new TextDecoder('utf-8');
      this.header = reader.uint32();
      const size = reader.uint32(); // header size; the partition table starts right after it
      this.version = reader.uint32();
      this.checksum = reader.read(64);
      reader.skip(4);
      this.is_encrypt = reader.byte();
      this.is_checksum = reader.byte();
      this.type = reader.byte(); // 0=IR model, 1=standard model, 2=OM Tiny model
      this.mode = reader.byte(); // 0=offline, 1=online
      this.name = decoder.decode(reader.read(32));
      this.ops = reader.uint32();
      this.userdefineinfo = reader.read(32);
      this.ir_version = reader.uint32();
      this.model_num = reader.uint32();
      this.platform_version = reader.read(20);
      this.platform_type = reader.byte();
      reader.seek(0);
      reader.skip(size);
      const partitions = new Array(reader.uint32());
      for (let i = 0; i < partitions.length; i++) {
        partitions[i] = {
          type: reader.uint32(),
          offset: reader.uint32(),
          size: reader.uint32()
        };
      }
      // Payload offsets are relative to the end of the table, which starts at `size` (not a fixed 256),
      // followed by the 4-byte count and 12 bytes per entry.
      const offset = size + 4 + 12 * partitions.length;
      for (const partition of partitions) {
        reader.seek(offset + partition.offset);
        const buffer = reader.read(partition.size);
        switch (partition.type) {
          case 0: { // MODEL_DEF
            this._model = buffer;
            break;
          }
          case 1: { // MODEL_WEIGHT
            this._weights = buffer;
            break;
          }
          case 2: // TASK_INFO
          case 3: // TBE_KERNELS
          case 4: { // CUST_AICPU_KERNELS
            break;
          }
          case 5: { // DEVICE_CONFIG
            this.devices = new Map();
            const textDecoder = new TextDecoder('ascii');
            const deviceReader = new base.BinaryReader(buffer);
            deviceReader.uint32();
            // Entries: uint32 name length, name bytes, uint32 device value.
            for (let position = 4; position < partition.size;) {
              const length = deviceReader.uint32();
              const name = textDecoder.decode(deviceReader.read(length));
              const device = deviceReader.uint32();
              this.devices.set(name, device);
              position += 4 + length + 4;
            }
            break;
          }
          default: {
            throw new om.Error("Unknown partition type '" + partition.type + "'.");
          }
        }
      }
    }
  };
  /**
   * Shared conversion helpers.
   */
  om.Utility = class {

    /**
     * Maps a numeric data type code to its display name.
     * @param {number} value - data type code, used as an index into the table below.
     * @returns {string}
     * @throws {om.Error} when the code is not a valid table index (negative, fractional or too large).
     */
    static dtype(value) {
      om.Utility._types = om.Utility._types || [
        'undefined', 'float32', 'float16', 'int8', 'uint8', 'int16', 'uint16', 'int32',
        'int64', 'uint32', 'uint64', 'boolean', 'float64', 'string', 'dual_sub_int8', 'dual_sub_uint8',
        'complex64', 'complex128', 'qint8', 'qint16', 'qint32', 'quint8', 'quint16', 'resource',
        'stringref', 'dual', 'variant', 'bfloat16', 'int4', 'uint1', 'int2', 'uint2'
      ];
      // Number() keeps integer-like objects working (they coerce via valueOf/toString).
      const index = Number(value);
      if (Number.isInteger(index) && index >= 0 && index < om.Utility._types.length) {
        return om.Utility._types[index];
      }
      throw new om.Error("Unknown dtype '" + value + "'.");
    }

    /**
     * @param {*} value - UTF-8 encoded bytes.
     * @returns {string}
     */
    static decodeText(value) {
      om.Utility._textDecoder = om.Utility._textDecoder || new TextDecoder('utf-8');
      return om.Utility._textDecoder.decode(value);
    }
  };
  /**
   * Operator schema lookup backed by 'om-metadata.json'. Loaded once and cached on the class.
   */
  om.Metadata = class {

    /**
     * @param {object} context - loading context exposing `request(name, encoding, base)`.
     * @returns {Promise<om.Metadata>} the cached instance, or an empty one if loading or parsing fails.
     */
    static open(context) {
      const cached = om.Metadata._metadata;
      if (cached) {
        return Promise.resolve(cached);
      }
      const remember = (data) => {
        om.Metadata._metadata = new om.Metadata(data);
        return om.Metadata._metadata;
      };
      return context.request('om-metadata.json', 'utf-8', null)
        .then((data) => remember(data))
        .catch(() => remember(null));
    }

    /**
     * @param {string|null} data - JSON array of `{ name, attributes, ... }` schemas, or null for none.
     */
    constructor(data) {
      const schemas = data ? JSON.parse(data) : [];
      this._map = new Map(schemas.map((schema) => [ schema.name, schema ]));
      this._attributes = new Map();
    }

    /** @returns {object|undefined} schema for the op type. */
    type(name) {
      return this._map.get(name);
    }

    /**
     * @returns {object|null} the attribute schema for `type:name`, memoized per key.
     */
    attribute(type, name) {
      const key = type + ':' + name;
      if (this._attributes.has(key)) {
        return this._attributes.get(key);
      }
      const schema = this.type(type);
      const declared = schema && schema.attributes && schema.attributes.length > 0 ? schema.attributes : [];
      for (const attribute of declared) {
        this._attributes.set(type + ':' + attribute.name, attribute);
      }
      if (!this._attributes.has(key)) {
        this._attributes.set(key, null);
      }
      return this._attributes.get(key);
    }
  };
  /**
   * Error raised for malformed or unsupported OM content.
   */
  om.Error = class extends Error {

    /**
     * @param {string} message - description of the problem.
     */
    constructor(message) {
      super(message);
      Object.assign(this, { name: 'Error loading DaVinci model.' });
    }
  };
  // Publish the factory only when a CommonJS `module.exports` object is available.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    Object.assign(module.exports, { ModelFactory: om.ModelFactory });
  }