view.js 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489
  1. /* jshint esversion: 6 */
  2. var view = view || {};
  3. var base = base || require('./base');
  4. var zip = zip || require('./zip');
  5. var gzip = gzip || require('./gzip');
  6. var tar = tar || require('./tar');
  7. var protobuf = protobuf || require('./protobuf');
  8. var d3 = d3 || require('d3');
  9. var dagre = dagre || require('dagre');
  10. var sidebar = sidebar || require('./view-sidebar');
  11. var grapher = grapher || require('./view-grapher');
  12. view.View = class {
  13. constructor(host) {
  14. this._host = host;
  15. this._host.initialize(this).then(() => {
  16. this._model = null;
  17. this._selection = [];
  18. this._sidebar = new sidebar.Sidebar(this._host);
  19. this._showAttributes = false;
  20. this._showInitializers = true;
  21. this._showNames = false;
  22. this._showHorizontal = false;
  23. this._searchText = '';
  24. this._modelFactoryService = new view.ModelFactoryService(this._host);
  25. this._host.document.getElementById('zoom-in-button').addEventListener('click', () => {
  26. this.zoomIn();
  27. });
  28. this._host.document.getElementById('zoom-out-button').addEventListener('click', () => {
  29. this.zoomOut();
  30. });
  31. this._host.document.getElementById('toolbar').addEventListener('mousewheel', (e) => {
  32. this._preventZoom(e);
  33. });
  34. this._host.document.getElementById('sidebar').addEventListener('mousewheel', (e) => {
  35. this._preventZoom(e);
  36. });
  37. this._host.document.addEventListener('keydown', () => {
  38. this.clearSelection();
  39. });
  40. if (this._host.environment('zoom') == 'scroll') {
  41. this._host.document.getElementById('graph').addEventListener('mousewheel', (e) => {
  42. this._mouseWheelHandler(e);
  43. });
  44. this._host.document.getElementById('graph').addEventListener('scroll', (e) => {
  45. this._scrollHandler(e);
  46. });
  47. this._host.document.getElementById('graph').addEventListener('gesturestart', (e) => {
  48. e.preventDefault();
  49. this._gestureStartZoom = this._zoom;
  50. }, false);
  51. this._host.document.getElementById('graph').addEventListener('gesturechange', (e) => {
  52. e.preventDefault();
  53. this._updateZoom(this._gestureStartZoom * e.scale, e);
  54. }, false);
  55. this._host.document.getElementById('graph').addEventListener('gestureend', (e) => {
  56. e.preventDefault();
  57. this._updateZoom(this._gestureStartZoom * e.scale, e);
  58. }, false);
  59. }
  60. this._host.start();
  61. }).catch((err) => {
  62. this.error(err, null, null);
  63. });
  64. }
  65. show(page) {
  66. if (!page) {
  67. page = (!this._model && !this._activeGraph) ? 'welcome' : 'default';
  68. }
  69. this._host.screen(page);
  70. if (this._sidebar) {
  71. this._sidebar.close();
  72. }
  73. this._host.document.body.setAttribute('class', page);
  74. }
  75. cut() {
  76. this._host.document.execCommand('cut');
  77. }
  78. copy() {
  79. this._host.document.execCommand('copy');
  80. }
  81. paste() {
  82. this._host.document.execCommand('paste');
  83. }
  84. selectAll() {
  85. this._host.document.execCommand('selectall');
  86. }
  87. find() {
  88. if (this._activeGraph) {
  89. this.clearSelection();
  90. const graphElement = document.getElementById('canvas');
  91. const view = new sidebar.FindSidebar(this._host, graphElement, this._activeGraph);
  92. view.on('search-text-changed', (sender, text) => {
  93. this._searchText = text;
  94. });
  95. view.on('select', (sender, selection) => {
  96. this._sidebar.close();
  97. this.select(selection);
  98. });
  99. this._sidebar.open(view.content, 'Find');
  100. view.focus(this._searchText);
  101. }
  102. }
  103. toggleAttributes() {
  104. this._showAttributes = !this._showAttributes;
  105. this._reload();
  106. }
  107. get showAttributes() {
  108. return this._showAttributes;
  109. }
  110. toggleInitializers() {
  111. this._showInitializers = !this._showInitializers;
  112. this._reload();
  113. }
  114. get showInitializers() {
  115. return this._showInitializers;
  116. }
  117. toggleNames() {
  118. this._showNames = !this._showNames;
  119. this._reload();
  120. }
  121. get showNames() {
  122. return this._showNames;
  123. }
  124. toggleDirection() {
  125. this._showHorizontal = !this._showHorizontal;
  126. this._reload();
  127. }
  128. get showHorizontal() {
  129. return this._showHorizontal;
  130. }
  131. _reload() {
  132. this.show('welcome spinner');
  133. if (this._model && this._activeGraph) {
  134. this._updateGraph(this._model, this._activeGraph).catch((error) => {
  135. if (error) {
  136. this.error(error, 'Graph update failed.', 'welcome');
  137. }
  138. });
  139. }
  140. }
  141. _timeout(time) {
  142. return new Promise((resolve) => {
  143. setTimeout(() => { resolve(); }, time);
  144. });
  145. }
  146. zoomIn() {
  147. switch (this._host.environment('zoom')) {
  148. case 'scroll':
  149. this._updateZoom(this._zoom * 1.05);
  150. break;
  151. case 'd3':
  152. if (this._zoom) {
  153. this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 1.2);
  154. }
  155. break;
  156. }
  157. }
  158. zoomOut() {
  159. switch (this._host.environment('zoom')) {
  160. case 'scroll':
  161. this._updateZoom(this._zoom * 0.95);
  162. break;
  163. case 'd3':
  164. if (this._zoom) {
  165. this._zoom.scaleBy(d3.select(this._host.document.getElementById('canvas')), 0.8);
  166. }
  167. break;
  168. }
  169. }
  170. resetZoom() {
  171. switch (this._host.environment('zoom')) {
  172. case 'scroll':
  173. this._updateZoom(1);
  174. break;
  175. case 'd3':
  176. if (this._zoom) {
  177. this._zoom.scaleTo(d3.select(this._host.document.getElementById('canvas')), 1);
  178. }
  179. break;
  180. }
  181. }
  182. _preventZoom(e) {
  183. if (e.shiftKey || e.ctrlKey) {
  184. e.preventDefault();
  185. }
  186. }
  187. _updateZoom(zoom, e) {
  188. const container = this._host.document.getElementById('graph');
  189. const min = Math.min(Math.max(container.clientHeight / this._height, 0.2), 1);
  190. zoom = Math.min(zoom, 2);
  191. zoom = Math.max(min, zoom);
  192. const scrollLeft = this._scrollLeft || container.scrollLeft;
  193. const scrollTop = this._scrollTop || container.scrollTop;
  194. const x = (e ? e.pageX : (container.clientWidth / 2)) + scrollLeft;
  195. const y = (e ? e.pageY : (container.clientHeight / 2)) + scrollTop;
  196. const graph = this._host.document.getElementById('canvas');
  197. graph.style.width = zoom * this._width;
  198. graph.style.height = zoom * this._height;
  199. this._scrollLeft = ((x * zoom) / this._zoom) - (x - scrollLeft);
  200. this._scrollTop = ((y * zoom) / this._zoom) - (y - scrollTop);
  201. this._scrollLeft = Math.max(0, this._scrollLeft);
  202. this._scrollTop = Math.max(0, this._scrollTop);
  203. container.scrollLeft = this._scrollLeft;
  204. container.scrollTop = this._scrollTop;
  205. this._zoom = zoom;
  206. }
  207. _mouseWheelHandler(e) {
  208. if (e.shiftKey || e.ctrlKey) {
  209. this._updateZoom(this._zoom + (e.wheelDelta * 1.0 / 4000.0), e);
  210. e.preventDefault();
  211. }
  212. }
  213. _scrollHandler(e) {
  214. if (this._scrollLeft && e.target.scrollLeft !== Math.floor(this._scrollLeft)) {
  215. delete this._scrollLeft;
  216. }
  217. if (this._scrollTop && e.target.scrollTop !== Math.floor(this._scrollTop)) {
  218. delete this._scrollTop;
  219. }
  220. }
  221. select(selection) {
  222. this.clearSelection();
  223. if (selection && selection.length > 0) {
  224. const graphElement = this._host.document.getElementById('canvas');
  225. const graphRect = graphElement.getBoundingClientRect();
  226. let x = 0;
  227. let y = 0;
  228. for (const element of selection) {
  229. element.classList.add('select');
  230. this._selection.push(element);
  231. const transform = element.transform.baseVal.consolidate();
  232. const box = element.getBBox();
  233. const ex = transform ? transform.matrix.e : box.x + (box.width / 2);
  234. const ey = transform ? transform.matrix.f : box.y + (box.height / 2);
  235. x += ex;
  236. y += ey;
  237. }
  238. x = x / selection.length;
  239. y = y / selection.length;
  240. this._zoom.transform(d3.select(graphElement), d3.zoomIdentity.translate((graphRect.width / 2) - x, (graphRect.height / 2) - y));
  241. }
  242. }
  243. clearSelection() {
  244. while (this._selection.length > 0) {
  245. const element = this._selection.pop();
  246. element.classList.remove('select');
  247. }
  248. }
  249. error(err, name, screen) {
  250. if (this._sidebar) {
  251. this._sidebar.close();
  252. }
  253. this._host.exception(err, false);
  254. const knowns = [
  255. { name: 'Error', message: /^EACCES: permission denied/, url: 'https://github.com/lutzroeder/netron/issues/504' },
  256. { name: 'Error loading Darknet model.', message: /^Cannot read property/, url: 'https://github.com/lutzroeder/netron/issues/539' },
  257. { name: 'Error loading Keras model.', message: /^Invalid argument identifier/, url: 'https://github.com/lutzroeder/netron/issues/540' },
  258. { name: 'Error loading Darknet model.', message: /^Invalid tensor shape/, url: 'https://github.com/lutzroeder/netron/issues/541' },
  259. { name: 'Error loading PyTorch model.', message: /^File does not contain root module or state dictionary/, url: 'https://github.com/lutzroeder/netron/issues/543' },
  260. { name: 'Error loading PyTorch model.', message: /^Module does not contain modules/, url: 'https://github.com/lutzroeder/netron/issues/544' },
  261. { name: 'Error loading PyTorch model.', message: /^Failed to resolve module/, url: 'https://github.com/lutzroeder/netron/issues/545' },
  262. { name: 'Error loading PyTorch model.', message: /^Unsupported function/, url: 'https://github.com/lutzroeder/netron/issues/546' },
  263. { name: 'Error loading PyTorch model.', message: /^Unsupported uninitialized argument/, url: 'https://github.com/lutzroeder/netron/issues/547' },
  264. { name: 'Error loading Keras model.', message: /^Unsupported data object header version/, url: 'https://github.com/lutzroeder/netron/issues/548' },
  265. { name: 'Error loading ONNX model.', message: /^File format is not onnx\.ModelProto/, url: 'https://github.com/lutzroeder/netron/issues/549' },
  266. { name: 'Error loading model.', message: /^Unsupported file content \(/, url: 'https://github.com/lutzroeder/netron/issues/550' },
  267. { name: 'Error', message: /^EPERM: operation not permitted/, url: 'https://github.com/lutzroeder/netron/issues/551' },
  268. { name: 'Error loading UFF model.', message: /^Unknown data type/, url: 'https://github.com/lutzroeder/netron/issues/561' },
  269. { name: 'RangeError', message: /^Offset is outside the bounds of the DataView/, url: 'https://github.com/lutzroeder/netron/issues/563' },
  270. { name: 'RangeError', message: /^start offset of Int32Array/, url: 'https://github.com/lutzroeder/netron/issues/565' }
  271. ];
  272. const known = knowns.find((known) => err.name === known.name && err.message.match(known.message));
  273. const message = (name ? err.toString() : err.message) + (known ? '\n\nPlease provide information about this issue at ' + known.url + '.' : '');
  274. name = name || err.name;
  275. this._host.error(name, message);
  276. this.show(screen !== undefined ? screen : 'welcome');
  277. if (known) {
  278. this._host.openURL(known.url);
  279. }
  280. }
  281. accept(file) {
  282. return this._modelFactoryService.accept(file);
  283. }
  284. open(context) {
  285. this._host.event('Model', 'Open', 'Size', context.buffer.length);
  286. this._sidebar.close();
  287. return this._timeout(2).then(() => {
  288. return this._modelFactoryService.open(context).then((model) => {
  289. const format = model.format;
  290. if (format) {
  291. this._host.event('Model', 'Format', format + (model.producer ? ' (' + model.producer + ')' : ''));
  292. }
  293. return this._timeout(20).then(() => {
  294. const graph = model.graphs.length > 0 ? model.graphs[0] : null;
  295. return this._updateGraph(model, graph);
  296. });
  297. });
  298. });
  299. }
  300. _updateActiveGraph(name) {
  301. this._sidebar.close();
  302. if (this._model) {
  303. const model = this._model;
  304. const graph = model.graphs.filter(graph => name == graph.name).shift();
  305. if (graph) {
  306. this.show('welcome spinner');
  307. this._timeout(200).then(() => {
  308. return this._updateGraph(model, graph).catch((error) => {
  309. if (error) {
  310. this.error(error, 'Graph update failed.', 'welcome');
  311. }
  312. });
  313. });
  314. }
  315. }
  316. }
  317. _updateGraph(model, graph) {
  318. return this._timeout(100).then(() => {
  319. if (graph && graph != this._activeGraph) {
  320. const nodes = graph.nodes;
  321. if (nodes.length > 1400) {
  322. if (!this._host.confirm('Large model detected.', 'This graph contains a large number of nodes and might take a long time to render. Do you want to continue?')) {
  323. this._host.event('Graph', 'Render', 'Skip', nodes.length);
  324. this.show(null);
  325. return null;
  326. }
  327. }
  328. }
  329. return this.renderGraph(model, graph).then(() => {
  330. this._model = model;
  331. this._activeGraph = graph;
  332. this.show('default');
  333. return this._model;
  334. }).catch((error) => {
  335. return this.renderGraph(this._model, this._activeGraph).then(() => {
  336. this.show('default');
  337. throw error;
  338. }).catch(() => {
  339. throw error;
  340. });
  341. });
  342. });
  343. }
  344. renderGraph(model, graph) {
  345. try {
  346. const graphElement = this._host.document.getElementById('canvas');
  347. while (graphElement.lastChild) {
  348. graphElement.removeChild(graphElement.lastChild);
  349. }
  350. if (!graph) {
  351. return Promise.resolve();
  352. }
  353. else {
  354. switch (this._host.environment('zoom')) {
  355. case 'scroll':
  356. this._zoom = 0;
  357. graphElement.style.position = 'static';
  358. graphElement.style.margin = 'auto';
  359. break;
  360. case 'd3':
  361. this._zoom = null;
  362. graphElement.style.position = 'absolute';
  363. graphElement.style.margin = '0';
  364. break;
  365. }
  366. const groups = graph.groups;
  367. const graphOptions = {};
  368. graphOptions.nodesep = 25;
  369. graphOptions.ranksep = 20;
  370. const rotate = graph.nodes.every((node) => node.inputs.filter((input) => input.arguments.every((argument) => !argument.initializer)).length === 0 && node.outputs.length === 0);
  371. const showHorizontal = rotate ? !this._showHorizontal : this._showHorizontal;
  372. if (showHorizontal) {
  373. graphOptions.rankdir = "LR";
  374. }
  375. const g = new dagre.graphlib.Graph({ compound: groups });
  376. g.setGraph(graphOptions);
  377. g.setDefaultEdgeLabel(() => { return {}; });
  378. let nodeId = 0;
  379. const edgeMap = {};
  380. const clusterMap = {};
  381. const clusterParentMap = {};
  382. let id = new Date().getTime();
  383. const nodes = graph.nodes;
  384. if (nodes.length > 1500) {
  385. graphOptions.ranker = 'longest-path';
  386. }
  387. this._host.event('Graph', 'Render', 'Size', nodes.length);
  388. if (groups) {
  389. for (const node of nodes) {
  390. if (node.group) {
  391. const path = node.group.split('/');
  392. while (path.length > 0) {
  393. const name = path.join('/');
  394. path.pop();
  395. clusterParentMap[name] = path.join('/');
  396. }
  397. }
  398. }
  399. }
  400. const self = this;
  401. for (const node of nodes) {
  402. const element = new grapher.NodeElement(this._host.document);
  403. const addNode = function(element, node, edges) {
  404. const header = element.block('header');
  405. const styles = [ 'node-item-type' ];
  406. const metadata = node.metadata;
  407. const category = metadata && metadata.category ? metadata.category : '';
  408. if (category) {
  409. styles.push('node-item-type-' + category.toLowerCase());
  410. }
  411. const type = node.type;
  412. if (typeof type !== 'string' || !type.split) { // #416
  413. throw new ModelError("Unknown node type '" + JSON.stringify(type) + "' in '" + model.format + "'.");
  414. }
  415. const content = self.showNames && node.name ? node.name : type.split('.').pop();
  416. const tooltip = self.showNames && node.name ? type : node.name;
  417. header.add(null, styles, content, tooltip, () => {
  418. self.showNodeProperties(node, null);
  419. });
  420. if (node.function) {
  421. header.add(null, [ 'node-item-function' ], '+', null, () => {
  422. // debugger;
  423. });
  424. }
  425. const initializers = [];
  426. let hiddenInitializers = false;
  427. if (self._showInitializers) {
  428. for (const input of node.inputs) {
  429. if (input.visible && input.arguments.length == 1 && input.arguments[0].initializer != null) {
  430. initializers.push(input);
  431. }
  432. if ((!input.visible || input.arguments.length > 1) &&
  433. input.arguments.some((argument) => argument.initializer != null)) {
  434. hiddenInitializers = true;
  435. }
  436. }
  437. }
  438. let sortedAttributes = [];
  439. const attributes = node.attributes;
  440. if (self.showAttributes && attributes) {
  441. sortedAttributes = attributes.filter((attribute) => attribute.visible).slice();
  442. sortedAttributes.sort((a, b) => {
  443. const au = a.name.toUpperCase();
  444. const bu = b.name.toUpperCase();
  445. return (au < bu) ? -1 : (au > bu) ? 1 : 0;
  446. });
  447. }
  448. if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
  449. const block = element.block('list');
  450. block.handler = () => {
  451. self.showNodeProperties(node);
  452. };
  453. for (const initializer of initializers) {
  454. const argument = initializer.arguments[0];
  455. const type = argument.type;
  456. let shape = '';
  457. let separator = '';
  458. if (type &&
  459. type.shape &&
  460. type.shape.dimensions &&
  461. Object.prototype.hasOwnProperty.call(type.shape.dimensions, 'length')) {
  462. shape = '\u3008' + type.shape.dimensions.map((d) => d ? d : '?').join('\u00D7') + '\u3009';
  463. if (type.shape.dimensions.length == 0 && argument.initializer && !argument.initializer.state) {
  464. shape = argument.initializer.toString();
  465. if (shape && shape.length > 10) {
  466. shape = shape.substring(0, 10) + '\u2026';
  467. }
  468. separator = ' = ';
  469. }
  470. }
  471. block.add('initializer-' + argument.name, initializer.name, shape, type ? type.toString() : '', separator);
  472. }
  473. if (hiddenInitializers) {
  474. block.add(null, '\u3008' + '\u2026' + '\u3009', '', null, '');
  475. }
  476. for (const attribute of sortedAttributes) {
  477. if (attribute.visible) {
  478. let attributeValue = sidebar.NodeSidebar.formatAttributeValue(attribute.value, attribute.type);
  479. if (attributeValue && attributeValue.length > 25) {
  480. attributeValue = attributeValue.substring(0, 25) + '\u2026';
  481. }
  482. block.add(null, attribute.name, attributeValue, attribute.type, ' = ');
  483. }
  484. }
  485. }
  486. if (edges) {
  487. const inputs = node.inputs;
  488. for (const input of inputs) {
  489. for (const argument of input.arguments) {
  490. if (argument.name != '' && !argument.initializer) {
  491. let tuple = edgeMap[argument.name];
  492. if (!tuple) {
  493. tuple = { from: null, to: [] };
  494. edgeMap[argument.name] = tuple;
  495. }
  496. tuple.to.push({
  497. node: nodeId,
  498. name: input.name
  499. });
  500. }
  501. }
  502. }
  503. let outputs = node.outputs;
  504. if (node.chain && node.chain.length > 0) {
  505. const chainOutputs = node.chain[node.chain.length - 1].outputs;
  506. if (chainOutputs.length > 0) {
  507. outputs = chainOutputs;
  508. }
  509. }
  510. for (const output of outputs) {
  511. for (const argument of output.arguments) {
  512. if (argument.name != '') {
  513. let tuple = edgeMap[argument.name];
  514. if (!tuple) {
  515. tuple = { from: null, to: [] };
  516. edgeMap[argument.name] = tuple;
  517. }
  518. tuple.from = {
  519. node: nodeId,
  520. name: output.name,
  521. type: argument.type
  522. };
  523. }
  524. }
  525. }
  526. }
  527. if (node.chain && node.chain.length > 0) {
  528. for (const innerNode of node.chain) {
  529. addNode(element, innerNode, false);
  530. }
  531. }
  532. if (node.inner) {
  533. addNode(element, node.inner, false);
  534. }
  535. };
  536. addNode(element, node, true);
  537. if (node.controlDependencies && node.controlDependencies.length > 0) {
  538. for (const controlDependency of node.controlDependencies) {
  539. let tuple = edgeMap[controlDependency];
  540. if (!tuple) {
  541. tuple = { from: null, to: [] };
  542. edgeMap[controlDependency] = tuple;
  543. }
  544. tuple.to.push({
  545. node: nodeId,
  546. name: controlDependency,
  547. controlDependency: true
  548. });
  549. }
  550. }
  551. const nodeName = node.name;
  552. if (nodeName) {
  553. g.setNode(nodeId, { label: element.format(graphElement), id: 'node-' + nodeName, class: 'graph-node' });
  554. }
  555. else {
  556. g.setNode(nodeId, { label: element.format(graphElement), id: 'node-' + id.toString(), class: 'graph-node' });
  557. id++;
  558. }
  559. const createCluster = function(name) {
  560. if (!clusterMap[name]) {
  561. g.setNode(name, { rx: 5, ry: 5});
  562. clusterMap[name] = true;
  563. const parent = clusterParentMap[name];
  564. if (parent) {
  565. createCluster(parent);
  566. g.setParent(name, parent);
  567. }
  568. }
  569. };
  570. if (groups) {
  571. let groupName = node.group;
  572. if (groupName && groupName.length > 0) {
  573. if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) {
  574. const lastIndex = groupName.lastIndexOf('/');
  575. if (lastIndex != -1) {
  576. groupName = groupName.substring(0, lastIndex);
  577. if (!Object.prototype.hasOwnProperty.call(clusterParentMap, groupName)) {
  578. groupName = null;
  579. }
  580. }
  581. else {
  582. groupName = null;
  583. }
  584. }
  585. if (groupName) {
  586. createCluster(groupName);
  587. g.setParent(nodeId, groupName);
  588. }
  589. }
  590. }
  591. nodeId++;
  592. }
  593. for (const input of graph.inputs) {
  594. for (const argument of input.arguments) {
  595. let tuple = edgeMap[argument.name];
  596. if (!tuple) {
  597. tuple = { from: null, to: [] };
  598. edgeMap[argument.name] = tuple;
  599. }
  600. tuple.from = {
  601. node: nodeId,
  602. type: argument.type
  603. };
  604. }
  605. const types = input.arguments.map((argument) => argument.type || '').join('\n');
  606. let inputName = input.name || '';
  607. if (inputName.length > 16) {
  608. inputName = inputName.split('/').pop();
  609. }
  610. const inputElement = new grapher.NodeElement(this._host.document);
  611. const inputHeader = inputElement.block('header');
  612. inputHeader.add(null, [ 'graph-item-input' ], inputName, types, () => {
  613. this.showModelProperties();
  614. });
  615. g.setNode(nodeId++, { label: inputElement.format(graphElement), class: 'graph-input' } );
  616. }
  617. for (const output of graph.outputs) {
  618. for (const argument of output.arguments) {
  619. let tuple = edgeMap[argument.name];
  620. if (!tuple) {
  621. tuple = { from: null, to: [] };
  622. edgeMap[argument.name] = tuple;
  623. }
  624. tuple.to.push({ node: nodeId });
  625. }
  626. const outputTypes = output.arguments.map((argument) => argument.type || '').join('\n');
  627. let outputName = output.name || '';
  628. if (outputName.length > 16) {
  629. outputName = outputName.split('/').pop();
  630. }
  631. const outputElement = new grapher.NodeElement(this._host.document);
  632. const outputHeader = outputElement.block('header');
  633. outputHeader.add(null, [ 'graph-item-output' ], outputName, outputTypes, () => {
  634. this.showModelProperties();
  635. });
  636. g.setNode(nodeId++, { label: outputElement.format(graphElement) } );
  637. }
  638. for (const edge of Object.keys(edgeMap)) {
  639. const tuple = edgeMap[edge];
  640. if (tuple.from != null) {
  641. for (const to of tuple.to) {
  642. let text = '';
  643. const type = tuple.from.type;
  644. if (type && type.shape && type.shape.dimensions && type.shape.dimensions.length > 0) {
  645. text = type.shape.dimensions.join('\u00D7');
  646. }
  647. if (this._showNames) {
  648. text = edge.split('\n').shift(); // custom argument id
  649. }
  650. if (to.controlDependency) {
  651. g.setEdge(tuple.from.node, to.node, { label: text, id: 'edge-' + edge, arrowhead: 'vee', class: 'edge-path-control-dependency' } );
  652. }
  653. else {
  654. g.setEdge(tuple.from.node, to.node, { label: text, id: 'edge-' + edge, arrowhead: 'vee' } );
  655. }
  656. }
  657. }
  658. }
  659. // Workaround for Safari background drag/zoom issue:
  660. // https://stackoverflow.com/questions/40887193/d3-js-zoom-is-not-working-with-mousewheel-in-safari
  661. const backgroundElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'rect');
  662. backgroundElement.setAttribute('id', 'background');
  663. if (this._host.environment('zoom') == 'd3') {
  664. backgroundElement.setAttribute('width', '100%');
  665. backgroundElement.setAttribute('height', '100%');
  666. }
  667. backgroundElement.setAttribute('fill', 'none');
  668. backgroundElement.setAttribute('pointer-events', 'all');
  669. graphElement.appendChild(backgroundElement);
  670. const originElement = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'g');
  671. originElement.setAttribute('id', 'origin');
  672. graphElement.appendChild(originElement);
  673. let svg = null;
  674. if (this._host.environment('zoom') == 'd3') {
  675. svg = d3.select(graphElement);
  676. this._zoom = d3.zoom();
  677. this._zoom(svg);
  678. this._zoom.scaleExtent([0.1, 2]);
  679. this._zoom.on('zoom', () => {
  680. originElement.setAttribute('transform', d3.event.transform.toString());
  681. });
  682. this._zoom.transform(svg, d3.zoomIdentity);
  683. }
  684. return this._timeout(20).then(() => {
  685. const graphRenderer = new grapher.Renderer(this._host.document, originElement);
  686. graphRenderer.render(g);
  687. const originElements = Array.from(graphElement.getElementsByClassName('graph-input') || []);
  688. if (originElements.length === 0) {
  689. const nodeElements = Array.from(graphElement.getElementsByClassName('graph-node') || []);
  690. if (nodeElements.length > 0) {
  691. originElements.push(nodeElements[0]);
  692. }
  693. }
  694. switch (this._host.environment('zoom')) {
  695. case 'scroll': {
  696. const size = graphElement.getBBox();
  697. const margin = 100;
  698. const width = Math.ceil(margin + size.width + margin);
  699. const height = Math.ceil(margin + size.height + margin);
  700. originElement.setAttribute('transform', 'translate(' + margin.toString() + ', ' + margin.toString() + ') scale(1)');
  701. backgroundElement.setAttribute('width', width);
  702. backgroundElement.setAttribute('height', height);
  703. this._width = width;
  704. this._height = height;
  705. this._zoom = 1;
  706. delete this._scrollLeft;
  707. delete this._scrollRight;
  708. graphElement.setAttribute('viewBox', '0 0 ' + width + ' ' + height);
  709. graphElement.setAttribute('width', width);
  710. graphElement.setAttribute('height', height);
  711. if (originElements && originElements.length > 0) {
  712. // Center view based on input elements
  713. for (let j = 0; j < originElements.length; j++) {
  714. originElements[j].scrollIntoView({ behavior: 'instant' });
  715. break;
  716. }
  717. }
  718. else {
  719. // this._zoom.transform(svg, d3.zoomIdentity.translate((svgSize.width - g.graph().width) / 2, (svgSize.height - g.graph().height) / 2));
  720. }
  721. break;
  722. }
  723. case 'd3': {
  724. const svgSize = graphElement.getBoundingClientRect();
  725. if (originElements && originElements.length > 0) {
  726. // Center view based on input elements
  727. const xs = [];
  728. const ys = [];
  729. for (let i = 0; i < originElements.length; i++) {
  730. const inputTransform = originElements[i].transform.baseVal.consolidate().matrix;
  731. xs.push(inputTransform.e);
  732. ys.push(inputTransform.f);
  733. }
  734. let x = xs[0];
  735. const y = ys[0];
  736. if (ys.every(y => y == ys[0])) {
  737. x = xs.reduce((a,b) => { return a + b; }) / xs.length;
  738. }
  739. const sx = (svgSize.width / (this._showHorizontal ? 4 : 2)) - x;
  740. const sy = (svgSize.height / (this._showHorizontal ? 2 : 4)) - y;
  741. this._zoom.transform(svg, d3.zoomIdentity.translate(sx, sy));
  742. }
  743. else {
  744. this._zoom.transform(svg, d3.zoomIdentity.translate((svgSize.width - g.graph().width) / 2, (svgSize.height - g.graph().height) / 2));
  745. }
  746. break;
  747. }
  748. }
  749. return;
  750. });
  751. }
  752. }
  753. catch (error) {
  754. return Promise.reject(error);
  755. }
  756. }
  757. applyStyleSheet(element, name) {
  758. let rules = [];
  759. for (let i = 0; i < this._host.document.styleSheets.length; i++) {
  760. const styleSheet = this._host.document.styleSheets[i];
  761. if (styleSheet && styleSheet.href && styleSheet.href.endsWith('/' + name)) {
  762. rules = styleSheet.cssRules;
  763. break;
  764. }
  765. }
  766. const nodes = element.getElementsByTagName('*');
  767. for (let j = 0; j < nodes.length; j++) {
  768. const node = nodes[j];
  769. for (let k = 0; k < rules.length; k++) {
  770. const rule = rules[k];
  771. if (node.matches(rule.selectorText)) {
  772. for (let l = 0; l < rule.style.length; l++) {
  773. const item = rule.style.item(l);
  774. node.style[item] = rule.style[item];
  775. }
  776. }
  777. }
  778. }
  779. }
  780. export(file) {
  781. const lastIndex = file.lastIndexOf('.');
  782. const extension = (lastIndex != -1) ? file.substring(lastIndex + 1) : '';
  783. if (this._activeGraph && (extension == 'png' || extension == 'svg')) {
  784. const graphElement = this._host.document.getElementById('canvas');
  785. const exportElement = graphElement.cloneNode(true);
  786. this.applyStyleSheet(exportElement, 'view-grapher.css');
  787. exportElement.setAttribute('id', 'export');
  788. exportElement.removeAttribute('width');
  789. exportElement.removeAttribute('height');
  790. exportElement.style.removeProperty('opacity');
  791. exportElement.style.removeProperty('display');
  792. const backgroundElement = exportElement.querySelector('#background');
  793. const originElement = exportElement.querySelector('#origin');
  794. originElement.setAttribute('transform', 'translate(0,0) scale(1)');
  795. backgroundElement.removeAttribute('width');
  796. backgroundElement.removeAttribute('height');
  797. const parentElement = graphElement.parentElement;
  798. parentElement.insertBefore(exportElement, graphElement);
  799. const size = exportElement.getBBox();
  800. parentElement.removeChild(exportElement);
  801. parentElement.removeChild(graphElement);
  802. parentElement.appendChild(graphElement);
  803. const delta = (Math.min(size.width, size.height) / 2.0) * 0.1;
  804. const width = Math.ceil(delta + size.width + delta);
  805. const height = Math.ceil(delta + size.height + delta);
  806. originElement.setAttribute('transform', 'translate(' + delta.toString() + ', ' + delta.toString() + ') scale(1)');
  807. exportElement.setAttribute('width', width);
  808. exportElement.setAttribute('height', height);
  809. backgroundElement.setAttribute('width', width);
  810. backgroundElement.setAttribute('height', height);
  811. backgroundElement.setAttribute('fill', '#fff');
  812. const data = new XMLSerializer().serializeToString(exportElement);
  813. if (extension == 'svg') {
  814. const blob = new Blob([ data ], { type: 'image/svg' });
  815. this._host.export(file, blob);
  816. }
  817. if (extension == 'png') {
  818. const imageElement = new Image();
  819. imageElement.onload = () => {
  820. const max = Math.max(width, height);
  821. const scale = ((max * 2.0) > 24000) ? (24000.0 / max) : 2.0;
  822. const canvas = this._host.document.createElement('canvas');
  823. canvas.width = Math.ceil(width * scale);
  824. canvas.height = Math.ceil(height * scale);
  825. const context = canvas.getContext('2d');
  826. context.scale(scale, scale);
  827. context.drawImage(imageElement, 0, 0);
  828. this._host.document.body.removeChild(imageElement);
  829. canvas.toBlob((blob) => {
  830. if (blob) {
  831. this._host.export(file, blob);
  832. }
  833. else {
  834. const err = new Error();
  835. err.name = 'Error exporting image.';
  836. err.message = 'Image may be too large to render as PNG.';
  837. this._host.exception(err, false);
  838. this._host.error(err.name, err.message);
  839. }
  840. }, 'image/png');
  841. };
  842. imageElement.src = 'data:image/svg+xml;base64,' + window.btoa(unescape(encodeURIComponent(data)));
  843. this._host.document.body.insertBefore(imageElement, this._host.document.body.firstChild);
  844. }
  845. }
  846. }
  847. showModelProperties() {
  848. if (this._model) {
  849. const modelSidebar = new sidebar.ModelSidebar(this._host, this._model, this._activeGraph);
  850. modelSidebar.on('update-active-graph', (sender, name) => {
  851. this._updateActiveGraph(name);
  852. });
  853. this._sidebar.open(modelSidebar.render(), 'Model Properties');
  854. }
  855. }
  856. showNodeProperties(node, input) {
  857. if (node) {
  858. const nodeSidebar = new sidebar.NodeSidebar(this._host, node);
  859. nodeSidebar.on('show-documentation', (/* sender, e */) => {
  860. this.showNodeDocumentation(node);
  861. });
  862. nodeSidebar.on('export-tensor', (sender, tensor) => {
  863. this._host.require('./numpy').then((numpy) => {
  864. const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
  865. this._host.save('NumPy Array', 'npy', defaultPath, (file) => {
  866. try {
  867. const dataTypeMap = new Map([
  868. [ 'float16', 'f2' ], [ 'float32', 'f4' ], [ 'float64', 'f8' ],
  869. [ 'int8', 'i1' ], [ 'int16', 'i2'], [ 'int32', 'i4' ], [ 'int64', 'i8' ],
  870. [ 'uint8', 'u1' ], [ 'uint16', 'u2' ], [ 'uint32', 'u4' ], [ 'uint64', 'u8' ],
  871. [ 'qint8', 'i1' ], [ 'qint16', 'i2' ],
  872. [ 'quint8', 'u1' ], [ 'quint16', 'u2' ]
  873. ]);
  874. const array = new numpy.Array();
  875. array.shape = tensor.type.shape.dimensions;
  876. array.data = tensor.value;
  877. array.dataType = dataTypeMap.has(tensor.type.dataType) ? dataTypeMap.get(tensor.type.dataType) : tensor.type.dataType;
  878. const blob = new Blob([ array.toBuffer() ], { type: 'application/octet-stream' });
  879. this._host.export(file, blob);
  880. }
  881. catch (error) {
  882. this.error(error, 'Error saving NumPy tensor.', null);
  883. }
  884. });
  885. }).catch(() => {
  886. });
  887. });
  888. if (input) {
  889. nodeSidebar.toggleInput(input.name);
  890. }
  891. this._sidebar.open(nodeSidebar.render(), 'Node Properties');
  892. }
  893. }
  894. showNodeDocumentation(node) {
  895. const metadata = node.metadata;
  896. if (metadata) {
  897. const documentationSidebar = new sidebar.DocumentationSidebar(this._host, metadata);
  898. documentationSidebar.on('navigate', (sender, e) => {
  899. this._host.openURL(e.link);
  900. });
  901. this._sidebar.push(documentationSidebar.render(), 'Documentation');
  902. }
  903. }
  904. };
  /**
   * Error raised when a model cannot be opened.
   * @param {string} message - Human-readable description.
   * @param {boolean} [telemetry] - Stored as-is on the instance. Callers in this
   *     file pass `true` or `!skip`; how it is consumed is not visible here.
   */
  class ModelError extends Error {
      constructor(message, telemetry) {
          super(message);
          Object.assign(this, {
              name: 'Error loading model.',
              telemetry: telemetry
          });
      }
  }
  /**
   * Wraps a host-provided context (an object exposing `identifier`, `buffer`
   * and `request(file, encoding)`). Lazily caches derived views of its content:
   * the decoded UTF-8 text, archive entry listings and protobuf tag scans.
   */
  class ModelContext {

      constructor(context) {
          this._context = context;
          this._tags = new Map();
          this._entries = new Map();
      }

      /**
       * Forwards a request for another file to the wrapped context.
       * @param {string} file
       * @param {string} [encoding]
       * @returns {Promise} Whatever the wrapped context's request() returns.
       */
      request(file, encoding) {
          return this._context.request(file, encoding);
      }

      get identifier() {
          return this._context.identifier;
      }

      get buffer() {
          return this._context.buffer;
      }

      /**
       * The buffer decoded as strict UTF-8, computed once.
       * Throws on invalid UTF-8, and decoding is retried on the next access.
       */
      get text() {
          // Test for undefined rather than truthiness: an empty file decodes to ''
          // and must be cached too, instead of being re-read on every access.
          if (this._text === undefined) {
              this._text = new TextDecoder('utf-8', { fatal: true }).decode(this.buffer);
          }
          return this._text;
      }

      /**
       * Lists the entries of the buffer interpreted as an archive.
       * @param {string} extension - 'zip' or 'tar'; any other value yields [].
       * @returns {Array} Entries from zip.Archive / tar.Archive. Returns [] when the
       *     signature/checksum does not match or parsing throws. Cached per extension.
       */
      entries(extension) {
          let entries = this._entries.get(extension);
          if (!entries) {
              entries = [];
              try {
                  const buffer = this.buffer;
                  switch (extension) {
                      case 'zip': {
                          // Zip files start with the bytes 'PK'.
                          if (buffer && buffer.length > 2 && buffer[0] === 0x50 && buffer[1] === 0x4B) {
                              entries = new zip.Archive(buffer).entries;
                          }
                          break;
                      }
                      case 'tar': {
                          if (ModelContext._isTarHeader(buffer)) {
                              entries = new tar.Archive(buffer).entries;
                          }
                          break;
                      }
                  }
              }
              catch (error) {
                  entries = [];
              }
              this._entries.set(extension, entries);
          }
          return entries;
      }

      /**
       * Scans the buffer for top-level protobuf tags.
       * @param {string} extension - 'pbtxt' (text format) or 'pb' (binary wire
       *     format); any other value yields an empty Map.
       * @returns {Map} For 'pbtxt': each tag returned by the text reader -> true.
       *     For 'pb': field number -> wire type. Empty when the content does not
       *     parse as that format. Cached per extension.
       */
      tags(extension) {
          let tags = this._tags.get(extension);
          if (!tags) {
              tags = new Map();
              try {
                  switch (extension) {
                      case 'pbtxt': {
                          const b = this.buffer;
                          const length = b.length;
                          // Byte order marks: UTF-8, UTF-32 BE, UTF-32 LE, GB-18030, UTF-16 BE, UTF-16 LE.
                          const signature =
                              (length >= 3 && b[0] === 0xef && b[1] === 0xbb && b[2] === 0xbf) ||
                              (length >= 4 && b[0] === 0x00 && b[1] === 0x00 && b[2] === 0xfe && b[3] === 0xff) ||
                              (length >= 4 && b[0] === 0xff && b[1] === 0xfe && b[2] === 0x00 && b[3] === 0x00) ||
                              (length >= 4 && b[0] === 0x84 && b[1] === 0x31 && b[2] === 0x95 && b[3] === 0x33) ||
                              (length >= 2 && b[0] === 0xfe && b[1] === 0xff) ||
                              (length >= 2 && b[0] === 0xff && b[1] === 0xfe);
                          // Without a BOM, control bytes other than 7..14 in the first 1 KB mean "not text".
                          if (!signature && b.subarray(0, Math.min(1024, length)).some((c) => c < 7 || (c > 14 && c < 32))) {
                              break;
                          }
                          const reader = protobuf.TextReader.create(this.buffer);
                          reader.start(false);
                          while (!reader.end(false)) {
                              const tag = reader.tag();
                              tags.set(tag, true);
                              reader.skip();
                          }
                          break;
                      }
                      case 'pb': {
                          // Accepted wire types; any other value means the buffer is not a protobuf message.
                          const tagTypes = new Set([ 0, 1, 2, 3, 5 ]);
                          const reader = protobuf.Reader.create(this.buffer);
                          const end = reader.next();
                          while (reader.pos < end) {
                              const tagType = reader.uint32();
                              tags.set(tagType >>> 3, tagType & 7);
                              if (!tagTypes.has(tagType & 7)) {
                                  tags = new Map();
                                  break;
                              }
                              try {
                                  reader.skipType(tagType & 7);
                              }
                              catch (error) {
                                  tags = new Map();
                                  break;
                              }
                          }
                          break;
                      }
                  }
              }
              catch (error) {
                  tags = new Map();
              }
              this._tags.set(extension, tags);
          }
          return tags;
      }

      /**
       * Validates the first 512-byte tar header block by its checksum. The sum of
       * all header bytes, with the 8-byte checksum field at offset 148 counted as
       * spaces (32), must equal the octal number stored in that field (read up to
       * the first NUL byte).
       * @param {Uint8Array} buffer
       * @returns {boolean}
       */
      static _isTarHeader(buffer) {
          if (!buffer || buffer.length < 512) {
              return false;
          }
          let sum = 0;
          for (let i = 0; i < 512; i++) {
              sum += (i >= 148 && i < 156) ? 32 : buffer[i];
          }
          let field = '';
          for (let i = 148; i < 156 && buffer[i] !== 0x00; i++) {
              field += String.fromCharCode(buffer[i]);
          }
          const checksum = parseInt(field, 8);
          return !isNaN(checksum) && sum === checksum;
      }
  }
  /**
   * Context for a model file found inside an archive. Sibling files located
   * directly in the same root folder can be fetched through request().
   */
  class ArchiveContext {

      /**
       * @param {Array|null} entries - Archive entries ({ name, data }); null for no siblings.
       * @param {string} rootFolder - Folder prefix (e.g. 'model/' or '') stripped from entry names.
       * @param {string} identifier - Full archive name of the model entry.
       * @param {*} buffer - Content of the model entry.
       */
      constructor(entries, rootFolder, identifier, buffer) {
          // A Map rather than a plain object: entry names come from the archive, and
          // names such as 'toString' or '__proto__' must not collide with Object.prototype.
          this._entries = new Map();
          if (entries) {
              for (const entry of entries) {
                  if (entry.name.startsWith(rootFolder)) {
                      const name = entry.name.substring(rootFolder.length);
                      // Only files directly inside rootFolder are reachable; later duplicates win.
                      if (name.length > 0 && name.indexOf('/') === -1) {
                          this._entries.set(name, entry);
                      }
                  }
              }
          }
          this._identifier = identifier.substring(rootFolder.length);
          this._buffer = buffer;
      }

      /**
       * Fetches a sibling file by its name relative to the root folder.
       * @param {string} file
       * @param {string} [encoding] - When given, the data is decoded to a string with TextDecoder.
       * @returns {Promise} Resolves with the raw entry data or the decoded string.
       *     Rejects with 'File not found.' for unknown names, or with the decoder's
       *     error for an invalid encoding label.
       */
      request(file, encoding) {
          const entry = this._entries.get(file);
          if (!entry) {
              return Promise.reject(new Error('File not found.'));
          }
          try {
              const data = encoding ? new TextDecoder(encoding).decode(entry.data) : entry.data;
              return Promise.resolve(data);
          }
          catch (error) {
              return Promise.reject(error);
          }
      }

      get identifier() {
          return this._identifier;
      }

      get buffer() {
          return this._buffer;
      }
  }
  /**
   * Error raised when an archive (gzip, tar or zip) cannot be read or does not
   * resolve to a single model file.
   * @param {string} message - Human-readable description.
   */
  class ArchiveError extends Error {
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading archive.' });
      }
  }
  /**
   * Maps file-name suffixes to the format modules ('./onnx', './tf', ...) that
   * may be able to read them. Opens models by trying each candidate module until
   * one matches the content and loads it.
   */
  view.ModelFactoryService = class {

      /**
       * @param {object} host - Host services. This class calls host.require(id) to
       *     load format modules and host.event(...) from accept(), and passes the
       *     host on to ModelFactory.open().
       */
      constructor(host) {
          this._host = host;
          this._extensions = [];
          this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model' ]);
          this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
          this.register('./keras', [ '.h5', '.hd5', '.hdf5', '.keras', '.json', '.model', '.pb', '.pth' ]);
          this.register('./coreml', [ '.mlmodel' ]);
          this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt' ]);
          this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);
          this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pkl', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.bin', '.pb', '.zip' ]);
          this.register('./torch', [ '.t7' ]);
          this.register('./tflite', [ '.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json' ]);
          this.register('./tf', [ '.pb', '.meta', '.pbtxt', '.prototxt', '.json', '.index', '.ckpt' ]);
          this.register('./mediapipe', [ '.pbtxt' ]);
          this.register('./uff', [ '.uff', '.pb', '.pbtxt', '.uff.txt', '.trt', '.engine' ]);
          this.register('./sklearn', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5' ]);
          this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
          this.register('./paddle', [ '.paddle', '.pdmodel', '__model__' ]);
          this.register('./bigdl', [ '.model', '.bigdl' ]);
          this.register('./darknet', [ '.cfg', '.model' ]);
          this.register('./armnn', [ '.armnn', '.json' ]);
          this.register('./mnn', ['.mnn']);
          this.register('./ncnn', [ '.param', '.bin', '.cfg.ncnn', '.weights.ncnn' ]);
          this.register('./tnn', [ '.tnnproto', '.tnnmodel' ]);
          this.register('./tengine', ['.tmfile']);
          this.register('./barracuda', [ '.nn' ]);
          this.register('./openvino', [ '.xml', '.bin' ]);
          this.register('./flux', [ '.bson' ]);
          this.register('./npz', [ '.npz', '.h5', '.hd5', '.hdf5' ]);
          this.register('./dl4j', [ '.zip' ]);
          this.register('./mlnet', [ '.zip' ]);
      }

      /**
       * Registers module `id` as a candidate reader for every suffix in `extensions`.
       * When several modules share a suffix, they are tried in registration order.
       * @param {string} id - Module path passed to host.require().
       * @param {string[]} extensions - File-name suffixes, e.g. '.onnx' or '__model__'.
       */
      register(id, extensions) {
          for (const extension of extensions) {
              this._extensions.push({ extension: extension, id: id });
          }
      }

      /**
       * Opens the model described by `context`.
       *
       * The steps are:
       * 1. Reject known non-model content (_openSignature).
       * 2. Unwrap archives (_openArchive).
       * 3. Try every module registered for the identifier's suffix.
       *
       * The first factory that matches the content and opens it successfully
       * supplies the result.
       * @param {object} context - Host context exposing identifier, buffer and request().
       * @returns {Promise} Resolves with the opened model. Rejects in these cases:
       *     - an ArchiveError from archive handling;
       *     - a ModelError for an unsupported extension or unrecognized content;
       *     - the factory's own error when exactly one matching factory failed;
       *     - a ModelError joining all messages when several matching factories failed.
       */
      open(context) {
          return this._openSignature(context).then((context) => {
              return this._openArchive(context);
          }).then((context) => {
              context = new ModelContext(context);
              const identifier = context.identifier;
              const extension = identifier.split('.').pop().toLowerCase();
              const modules = this._filter(context).filter((module) => module && module.length > 0);
              if (modules.length === 0) {
                  throw new ModelError("Unsupported file extension '." + extension + "'.");
              }
              const errors = [];
              let matched = false;
              const nextModule = () => {
                  if (modules.length > 0) {
                      const id = modules.shift();
                      return this._host.require(id).then((module) => {
                          if (!module.ModelFactory) {
                              throw new ModelError("Failed to load module '" + id + "'.");
                          }
                          const modelFactory = new module.ModelFactory();
                          if (!modelFactory.match(context)) {
                              return nextModule();
                          }
                          matched = true;
                          // A matching factory that fails to open is recorded; later candidates still get a chance.
                          return modelFactory.open(context, this._host).catch((error) => {
                              errors.push(error);
                              return nextModule();
                          });
                      });
                  }
                  if (matched) {
                      if (errors.length === 1) {
                          throw errors[0];
                      }
                      throw new ModelError(errors.map((err) => err.message).join('\n'));
                  }
                  // No factory recognized the content. For these identifiers the second
                  // ModelError argument is false; for all others it is true.
                  const knownUnsupportedIdentifiers = new Set([
                      'natives_blob.bin',
                      'v8_context_snapshot.bin',
                      'snapshot_blob.bin',
                      'image_net_labels.json',
                      'package.json',
                      'models.json',
                      'LICENSE.meta',
                      'input_0.pb',
                      'output_0.pb'
                  ]);
                  const skip = knownUnsupportedIdentifiers.has(identifier);
                  const buffer = context.buffer;
                  // Hex dump of the first 16 bytes, plus the byte length for files over 256 MiB.
                  const bytes = Array.from(buffer.subarray(0, Math.min(16, buffer.length))).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
                  const content = buffer.length > 268435456 ? '(' + bytes + ') [' + buffer.length.toString() + ']': '(' + bytes + ')';
                  throw new ModelError("Unsupported file content " + content + " for extension '." + extension + "' in '" + identifier + "'.", !skip);
              };
              return nextModule();
          });
      }

      /**
       * Unwraps gzip/tar/zip containers around a model file.
       *
       * A .gz/.tgz with a single member is decompressed first. Then the archive is
       * searched at its root folder for entries that a registered module matches.
       * The archive searched is the tar or zip archive, or the gzip wrapper itself
       * when its member is neither tar nor zip.
       *
       * The result is one of:
       * - exactly one match: a new ModelContext whose siblings stay reachable via request();
       * - an MXNet '-symbol.json' + '.params' pair: reduced to the '.params' entry;
       * - a zip holding a PyTorch ('version' + 'data.pkl') or dl4j
       *   ('coefficients.bin' + 'configuration.json') model: returned unchanged.
       * @returns {Promise} Resolves with `context` unchanged when there is nothing to
       *     unwrap or nothing matched, otherwise with a ModelContext for the match.
       *     Rejects with an ArchiveError when parsing fails or several entries match.
       */
      _openArchive(context) {
          let archive = null;
          let extension;
          let identifier = context.identifier;
          let buffer = context.buffer;
          // Wraps a parse failure, naming the identifier current at the time of the failure.
          const fail = (error) => {
              const message = error && error.message ? error.message : error.toString();
              return Promise.reject(new ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'."));
          };
          try {
              extension = identifier.split('.').pop().toLowerCase();
              if (extension === 'gz' || extension === 'tgz') {
                  archive = new gzip.Archive(buffer);
                  if (archive.entries.length === 1) {
                      const entry = archive.entries[0];
                      if (entry.name) {
                          identifier = entry.name;
                      }
                      else {
                          // No stored member name: drop the outer suffix ('.tgz' stands for '.tar.gz').
                          identifier = identifier.substring(0, identifier.lastIndexOf('.'));
                          if (extension === 'tgz') {
                              identifier += '.tar';
                          }
                      }
                      buffer = entry.data;
                  }
              }
          }
          catch (error) {
              return fail(error);
          }
          // Last path component, accepting both '/' and '\' separators.
          const baseName = (name) => name.split('/').pop().split('\\').pop();
          try {
              extension = identifier.split('.').pop().toLowerCase();
              switch (extension) {
                  case 'tar': {
                      // handle .pth.tar: a stream starting with 0x80 followed by the bytes
                      // below, or a zip ('PK'), is not a tar archive.
                      const torch = [ 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
                      if (buffer && buffer.length >= 14 && buffer[0] === 0x80 && torch.every((v, i) => v === buffer[i + 2])) {
                          break;
                      }
                      if (buffer && buffer.length >= 4 && buffer[0] === 0x50 && buffer[1] === 0x4B) {
                          break;
                      }
                      archive = new tar.Archive(buffer);
                      break;
                  }
                  case 'zip': {
                      archive = new zip.Archive(buffer);
                      const names = new Set(archive.entries.map((entry) => baseName(entry.name)));
                      // PyTorch Zip archive
                      if (names.has('version') && names.has('data.pkl')) {
                          return Promise.resolve(context);
                      }
                      // dl4j
                      if (names.has('coefficients.bin') && names.has('configuration.json')) {
                          return Promise.resolve(context);
                      }
                      break;
                  }
              }
          }
          catch (error) {
              return fail(error);
          }
          if (!archive) {
              return Promise.resolve(context);
          }
          try {
              // Directories and hidden files ('.name') never hold the model.
              const entries = archive.entries.filter((entry) => !entry.name.endsWith('/') && !entry.name.split('/').pop().startsWith('.'));
              // When every entry lives under one top-level folder, that folder is the root.
              const folders = new Set();
              for (const entry of entries) {
                  folders.add(entry.name.indexOf('/') !== -1 ? entry.name.split('/').shift() + '/' : '/');
              }
              if (extension === 'tar') {
                  folders.delete('PaxHeader/');
              }
              let rootFolder = folders.size === 1 ? folders.values().next().value : '';
              rootFolder = rootFolder === '/' ? '' : rootFolder;
              let matches = [];
              let position = 0;
              // Decides the outcome once every entry has been probed.
              const finish = () => {
                  if (matches.length === 0) {
                      return Promise.resolve(context);
                  }
                  // MXNet
                  if (matches.length === 2 &&
                      matches.some((e) => e.name.endsWith('.params')) &&
                      matches.some((e) => e.name.endsWith('-symbol.json'))) {
                      matches = matches.filter((e) => e.name.endsWith('.params'));
                  }
                  if (matches.length > 1) {
                      return Promise.reject(new ArchiveError('Archive contains multiple model files.'));
                  }
                  const match = matches[0];
                  return Promise.resolve(new ModelContext(new ArchiveContext(entries, rootFolder, match.name, match.data)));
              };
              // Tries the remaining candidate modules for one entry; the first match records it.
              const probe = (entry, entryContext, modules) => {
                  if (modules.length === 0) {
                      return nextEntry();
                  }
                  const id = modules.shift();
                  return this._host.require(id).then((module) => {
                      if (!module.ModelFactory) {
                          throw new ArchiveError("Failed to load module '" + id + "'.");
                      }
                      const factory = new module.ModelFactory();
                      if (factory.match(entryContext)) {
                          matches.push(entry);
                          return nextEntry();
                      }
                      return probe(entry, entryContext, modules);
                  });
              };
              // Advances to the next entry worth probing. Entries outside the root folder,
              // in sub-folders, hidden, or without a registered suffix are skipped in a
              // loop, so large archives do not build a deep synchronous call stack.
              const nextEntry = () => {
                  while (position < entries.length) {
                      const entry = entries[position++];
                      if (!entry.name.startsWith(rootFolder)) {
                          continue;
                      }
                      const name = entry.name.substring(rootFolder.length);
                      if (name.length === 0 || name.indexOf('/') !== -1 || name.startsWith('.')) {
                          continue;
                      }
                      // Filter on the name first so that entry data is only read for candidates.
                      const modules = this._filter({ identifier: name });
                      if (modules.length === 0) {
                          continue;
                      }
                      const entryContext = new ModelContext(new ArchiveContext(null, rootFolder, entry.name, entry.data));
                      return probe(entry, entryContext, modules);
                  }
                  return finish();
              };
              return nextEntry();
          }
          catch (error) {
              return Promise.reject(new ArchiveError(error.message));
          }
      }

      /**
       * Reports whether a file name looks openable: it ends with a registered
       * suffix or with an archive suffix (.zip, .tar, .tar.gz, .tgz). Either way a
       * 'File' Accept/Reject event is sent to the host with the lower-cased
       * last extension.
       * @param {string} identifier - File name.
       * @returns {boolean}
       */
      accept(identifier) {
          const extension = identifier.split('.').pop().toLowerCase();
          const name = identifier.toLowerCase();
          const archiveSuffixes = [ '.zip', '.tar', '.tar.gz', '.tgz' ];
          const accepted =
              this._extensions.some((entry) => name.endsWith(entry.extension)) ||
              archiveSuffixes.some((suffix) => name.endsWith(suffix));
          this._host.event('File', accepted ? 'Accept' : 'Reject', extension, 1);
          return accepted;
      }

      /**
       * Module ids registered for the context's identifier (case-insensitive
       * suffix match), de-duplicated while keeping registration order.
       * @param {{ identifier: string }} context
       * @returns {string[]}
       */
      _filter(context) {
          const identifier = context.identifier.toLowerCase();
          const ids = this._extensions.filter((entry) => identifier.endsWith(entry.extension)).map((entry) => entry.id);
          return Array.from(new Set(ids));
      }

      /**
       * Rejects empty files and content whose first 1 KB (decoded as UTF-8)
       * matches a known non-model signature.
       * @returns {Promise} Resolves with `context` unchanged, or rejects with a ModelError.
       */
      _openSignature(context) {
          const buffer = context.buffer;
          if (buffer.length === 0) {
              return Promise.reject(new ModelError("File has no content.", true));
          }
          const list = [
              { name: 'ELF executable', value: /^\x7FELF/ },
              { name: 'Git LFS header', value: /^version https:\/\/git-lfs.github.com\/spec\/v1\n/ },
              { name: 'Git LFS header', value: /^oid sha256:/ },
              { name: 'HTML markup', value: /^\s*<html>/ },
              { name: 'HTML markup', value: /^\s*<!DOCTYPE html>/ },
              { name: 'HTML markup', value: /^\s*<!DOCTYPE HTML>/ },
              { name: 'Unity metadata', value: /^fileFormatVersion:/ },
              { name: 'Vulkan SwiftShader ICD manifest', value: /^{\s*"file_format_version":\s*"1.0.0"\s*,\s*"ICD":/ },
              { name: 'StringIntLabelMapProto data', value: /^(#.*\n)*item\s*{\r?\n\s*id:/ },
              { name: 'StringIntLabelMapProto data', value: /^(#.*\n)*item\s*{\r?\n\s*name:/ },
              { name: 'ImageNet LabelMap data', value: /^(#.*\n)*entry\s*{\r?\n\s*target_class/ },
              { name: 'Python source code', value: /^\s*import sys, types, os;/ },
              { name: 'undocumented TensorRT engine data', value: /^ptrt/ },
              { name: 'TSD header', value: /^%TSD-Header-###%/ },
              { name: 'Darkflow metadata', value: /^{"net":\s*{"type":/ }
          ];
          const text = new TextDecoder().decode(buffer.subarray(0, Math.min(1024, buffer.length)));
          for (const item of list) {
              if (text.match(item.value)) {
                  return Promise.reject(new ModelError("Invalid file content. File contains " + item.name + ".", true));
              }
          }
          return Promise.resolve(context);
      }
  };
  // CommonJS export of the public classes; this is skipped unless a `module.exports` object exists.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, {
          View: view.View,
          ModelFactoryService: view.ModelFactoryService
      });
  }