view.js 95 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131
  1. var view = view || {};
  2. var base = base || require('./base');
  3. var zip = zip || require('./zip');
  4. var gzip = gzip || require('./gzip');
  5. var tar = tar || require('./tar');
  6. var json = json || require('./json');
  7. var xml = xml || require('./xml');
  8. var protobuf = protobuf || require('./protobuf');
  9. var flatbuffers = flatbuffers || require('./flatbuffers');
  10. var python = python || require('./python');
  11. var sidebar = sidebar || require('./view-sidebar');
  12. var grapher = grapher || require('./view-grapher');
  13. view.View = class {
  14. constructor(host, id) {
  15. this._host = host;
  16. this._id = id ? ('-' + id) : '';
  17. this._options = {
  18. initializers: true,
  19. attributes: false,
  20. names: false,
  21. direction: 'vertical',
  22. mousewheel: 'scroll'
  23. };
  24. this._host.initialize(this).then(() => {
  25. this._model = null;
  26. this._graphs = [];
  27. this._selection = [];
  28. this._sidebar = new sidebar.Sidebar(this._host, id);
  29. this._searchText = '';
  30. this._modelFactoryService = new view.ModelFactoryService(this._host);
  31. this._getElementById('zoom-in-button').addEventListener('click', () => {
  32. this.zoomIn();
  33. });
  34. this._getElementById('zoom-out-button').addEventListener('click', () => {
  35. this.zoomOut();
  36. });
  37. this._getElementById('back-button').addEventListener('click', () => {
  38. this.popGraph();
  39. });
  40. this._getElementById('name-button').addEventListener('click', () => {
  41. this.showDocumentation(this.activeGraph);
  42. });
  43. this._getElementById('sidebar').addEventListener('mousewheel', (e) => {
  44. this._preventDefault(e);
  45. }, { passive: true });
  46. this._host.document.addEventListener('keydown', () => {
  47. this.clearSelection();
  48. });
  49. this._host.start();
  50. const container = this._getElementById('graph');
  51. container.addEventListener('scroll', (e) => this._scrollHandler(e));
  52. container.addEventListener('wheel', (e) => this._wheelHandler(e), { passive: false });
  53. container.addEventListener('mousedown', (e) => this._mouseDownHandler(e));
  54. switch (this._host.agent) {
  55. case 'safari':
  56. container.addEventListener('gesturestart', (e) => this._gestureStartHandler(e), false);
  57. break;
  58. default:
  59. container.addEventListener('touchstart', (e) => this._touchStartHandler(e), { passive: true });
  60. break;
  61. }
  62. }).catch((err) => {
  63. this.error(err, null, null);
  64. });
  65. }
  66. show(page) {
  67. if (!page) {
  68. page = (!this._model && !this.activeGraph) ? 'welcome' : 'default';
  69. }
  70. this._host.screen(page);
  71. if (this._sidebar) {
  72. this._sidebar.close();
  73. }
  74. this._host.document.body.setAttribute('class', page);
  75. if (page === 'default') {
  76. const container = this._getElementById('graph');
  77. if (container) {
  78. container.focus();
  79. }
  80. }
  81. if (page === 'welcome') {
  82. const element = this._getElementById('open-file-button');
  83. if (element) {
  84. element.focus();
  85. }
  86. }
  87. this._page = page;
  88. }
  89. cut() {
  90. this._host.document.execCommand('cut');
  91. }
  92. copy() {
  93. this._host.document.execCommand('copy');
  94. }
  95. paste() {
  96. this._host.document.execCommand('paste');
  97. }
  98. selectAll() {
  99. this._host.document.execCommand('selectall');
  100. }
  101. find() {
  102. if (this._graph) {
  103. this.clearSelection();
  104. const graphElement = this._getElementById('canvas');
  105. const view = new sidebar.FindSidebar(this._host, graphElement, this._graph);
  106. view.on('search-text-changed', (sender, text) => {
  107. this._searchText = text;
  108. });
  109. view.on('select', (sender, selection) => {
  110. this.select(selection);
  111. });
  112. this._sidebar.open(view.content, 'Find');
  113. view.focus(this._searchText);
  114. }
  115. }
  116. get model() {
  117. return this._model;
  118. }
  119. get options() {
  120. return this._options;
  121. }
  122. toggle(name) {
  123. switch (name) {
  124. case 'names':
  125. case 'attributes':
  126. case 'initializers':
  127. this._options[name] = !this._options[name];
  128. this._reload();
  129. break;
  130. case 'direction':
  131. this._options.direction = this._options.direction === 'vertical' ? 'horizontal' : 'vertical';
  132. this._reload();
  133. break;
  134. case 'mousewheel':
  135. this._options.mousewheel = this._options.mousewheel === 'scroll' ? 'zoom' : 'scroll';
  136. break;
  137. default:
  138. throw new view.Error("Unsupported toogle '" + name + "'.");
  139. }
  140. }
  141. _reload() {
  142. this.show('welcome spinner');
  143. if (this._model && this._graphs.length > 0) {
  144. this._updateGraph(this._model, this._graphs).catch((error) => {
  145. if (error) {
  146. this.error(error, 'Graph update failed.', 'welcome');
  147. }
  148. });
  149. }
  150. }
  151. _timeout(time) {
  152. return new Promise((resolve) => {
  153. setTimeout(() => {
  154. resolve();
  155. }, time);
  156. });
  157. }
  158. _getElementById(id) {
  159. return this._host.document.getElementById(id + this._id);
  160. }
  161. zoomIn() {
  162. this._updateZoom(this._zoom * 1.1);
  163. }
  164. zoomOut() {
  165. this._updateZoom(this._zoom * 0.9);
  166. }
  167. resetZoom() {
  168. this._updateZoom(1);
  169. }
  170. _preventDefault(e) {
  171. if (e.shiftKey || e.ctrlKey) {
  172. e.preventDefault();
  173. }
  174. }
  175. _updateZoom(zoom, e) {
  176. const container = this._getElementById('graph');
  177. const canvas = this._getElementById('canvas');
  178. const limit = this._options.direction === 'vertical' ?
  179. container.clientHeight / this._height :
  180. container.clientWidth / this._width;
  181. const min = Math.min(Math.max(limit, 0.15), 1);
  182. zoom = Math.max(min, Math.min(zoom, 1.4));
  183. const scrollLeft = this._scrollLeft || container.scrollLeft;
  184. const scrollTop = this._scrollTop || container.scrollTop;
  185. const x = (e ? e.pageX : (container.clientWidth / 2)) + scrollLeft;
  186. const y = (e ? e.pageY : (container.clientHeight / 2)) + scrollTop;
  187. const width = zoom * this._width;
  188. const height = zoom * this._height;
  189. canvas.style.width = width + 'px';
  190. canvas.style.height = height + 'px';
  191. this._scrollLeft = Math.max(0, ((x * zoom) / this._zoom) - (x - scrollLeft));
  192. this._scrollTop = Math.max(0, ((y * zoom) / this._zoom) - (y - scrollTop));
  193. container.scrollLeft = this._scrollLeft;
  194. container.scrollTop = this._scrollTop;
  195. this._zoom = zoom;
  196. }
  197. _mouseDownHandler(e) {
  198. if (e.buttons === 1) {
  199. const document = this._host.document.documentElement;
  200. document.style.cursor = 'grabbing';
  201. const container = this._getElementById('graph');
  202. this._mousePosition = {
  203. left: container.scrollLeft,
  204. top: container.scrollTop,
  205. x: e.clientX,
  206. y: e.clientY
  207. };
  208. e.stopImmediatePropagation();
  209. const mouseMoveHandler = (e) => {
  210. e.preventDefault();
  211. e.stopImmediatePropagation();
  212. const dx = e.clientX - this._mousePosition.x;
  213. const dy = e.clientY - this._mousePosition.y;
  214. this._mousePosition.moved = dx * dx + dy * dy > 0;
  215. if (this._mousePosition.moved) {
  216. const container = this._getElementById('graph');
  217. container.scrollTop = this._mousePosition.top - dy;
  218. container.scrollLeft = this._mousePosition.left - dx;
  219. }
  220. };
  221. const mouseUpHandler = () => {
  222. document.style.cursor = null;
  223. container.removeEventListener('mouseup', mouseUpHandler);
  224. container.removeEventListener('mouseleave', mouseUpHandler);
  225. container.removeEventListener('mousemove', mouseMoveHandler);
  226. if (this._mousePosition && this._mousePosition.moved) {
  227. e.preventDefault();
  228. e.stopImmediatePropagation();
  229. delete this._mousePosition;
  230. document.addEventListener('click', clickHandler, true);
  231. }
  232. };
  233. const clickHandler = (e) => {
  234. e.stopPropagation();
  235. document.removeEventListener('click', clickHandler, true);
  236. };
  237. container.addEventListener('mousemove', mouseMoveHandler);
  238. container.addEventListener('mouseup', mouseUpHandler);
  239. container.addEventListener('mouseleave', mouseUpHandler);
  240. }
  241. }
  242. _touchStartHandler(e) {
  243. if (e.touches.length === 2) {
  244. this._touchPoints = Array.from(e.touches);
  245. this._touchZoom = this._zoom;
  246. }
  247. const touchMoveHandler = (e) => {
  248. if (Array.isArray(this._touchPoints) && this._touchPoints.length === 2 && e.touches.length === 2) {
  249. const distance = (points) => {
  250. const dx =(points[1].clientX - points[0].clientX);
  251. const dy =(points[1].clientY - points[0].clientY);
  252. return Math.sqrt(dx * dx + dy * dy);
  253. };
  254. const d1 = distance(Array.from(e.touches));
  255. const d2 = distance(this._touchPoints);
  256. if (d2 !== 0) {
  257. const points = this._touchPoints;
  258. const e = {
  259. pageX: (points[1].pageX + points[0].pageX) / 2,
  260. pageY: (points[1].pageY + points[0].pageY) / 2
  261. };
  262. const zoom = d2 === 0 ? d1 : d1 / d2;
  263. this._updateZoom(this._touchZoom * zoom, e);
  264. }
  265. }
  266. };
  267. const touchEndHandler = () => {
  268. container.removeEventListener('touchmove', touchMoveHandler, { passive: true });
  269. container.removeEventListener('touchcancel', touchEndHandler, { passive: true });
  270. container.removeEventListener('touchend', touchEndHandler, { passive: true });
  271. delete this._touchPoints;
  272. delete this._touchZoom;
  273. };
  274. const container = this._getElementById('graph');
  275. container.addEventListener('touchmove', touchMoveHandler, { passive: true });
  276. container.addEventListener('touchcancel', touchEndHandler, { passive: true });
  277. container.addEventListener('touchend', touchEndHandler, { passive: true });
  278. }
  279. _gestureStartHandler(e) {
  280. e.preventDefault();
  281. this._gestureZoom = this._zoom;
  282. const container = this._getElementById('graph');
  283. const gestureChangeHandler = (e) => {
  284. e.preventDefault();
  285. this._updateZoom(this._gestureZoom * e.scale, e);
  286. };
  287. const gestureEndHandler = (e) => {
  288. container.removeEventListener('gesturechange', gestureChangeHandler, false);
  289. container.removeEventListener('gestureend', gestureEndHandler, false);
  290. e.preventDefault();
  291. if (this._gestureZoom) {
  292. this._updateZoom(this._gestureZoom * e.scale, e);
  293. delete this._gestureZoom;
  294. }
  295. };
  296. container.addEventListener('gesturechange', gestureChangeHandler, false);
  297. container.addEventListener('gestureend', gestureEndHandler, false);
  298. }
  299. _scrollHandler(e) {
  300. if (this._scrollLeft && e.target.scrollLeft !== Math.floor(this._scrollLeft)) {
  301. delete this._scrollLeft;
  302. }
  303. if (this._scrollTop && e.target.scrollTop !== Math.floor(this._scrollTop)) {
  304. delete this._scrollTop;
  305. }
  306. }
  307. _wheelHandler(e) {
  308. if (e.shiftKey || e.ctrlKey || this._options.mousewheel === 'zoom') {
  309. const delta = -e.deltaY * (e.deltaMode === 1 ? 0.05 : e.deltaMode ? 1 : 0.002) * (e.ctrlKey ? 10 : 1);
  310. this._updateZoom(this._zoom * Math.pow(2, delta), e);
  311. e.preventDefault();
  312. }
  313. }
  314. select(selection) {
  315. this.clearSelection();
  316. if (selection && selection.length > 0) {
  317. const container = this._getElementById('graph');
  318. let x = 0;
  319. let y = 0;
  320. for (const element of selection) {
  321. element.classList.add('select');
  322. this._selection.push(element);
  323. const rect = element.getBoundingClientRect();
  324. x += rect.left + (rect.width / 2);
  325. y += rect.top + (rect.height / 2);
  326. }
  327. x = x / selection.length;
  328. y = y / selection.length;
  329. const rect = container.getBoundingClientRect();
  330. const left = (container.scrollLeft + x - rect.left) - (rect.width / 2);
  331. const top = (container.scrollTop + y - rect.top) - (rect.height / 2);
  332. container.scrollTo({ left: left, top: top, behavior: 'smooth' });
  333. }
  334. }
  335. clearSelection() {
  336. while (this._selection.length > 0) {
  337. const element = this._selection.pop();
  338. element.classList.remove('select');
  339. }
  340. }
  341. error(err, name, screen) {
  342. if (this._sidebar) {
  343. this._sidebar.close();
  344. }
  345. this._host.exception(err, false);
  346. const knowns = [
  347. { name: '', message: /^Invalid argument identifier/, url: 'https://github.com/lutzroeder/netron/issues/540' },
  348. { name: '', message: /^Cannot read property/, url: 'https://github.com/lutzroeder/netron/issues/647' },
  349. { name: '', message: /^Failed to render tensor/, url: 'https://github.com/lutzroeder/netron/issues/681' },
  350. { name: 'Error', message: /^EPERM: operation not permitted/, url: 'https://github.com/lutzroeder/netron/issues/551' },
  351. { name: 'Error', message: /^EACCES: permission denied/, url: 'https://github.com/lutzroeder/netron/issues/504' },
  352. { name: 'RangeError', message: /^Offset is outside the bounds of the DataView/, url: 'https://github.com/lutzroeder/netron/issues/563' },
  353. { name: 'RangeError', message: /^start offset of Int32Array/, url: 'https://github.com/lutzroeder/netron/issues/565' },
  354. { name: 'RangeError', message: /^Maximum call stack size exceeded/, url: 'https://github.com/lutzroeder/netron/issues/589' },
  355. { name: 'RangeError', message: /^Invalid string length/, url: 'https://github.com/lutzroeder/netron/issues/648' },
  356. { name: 'Error loading model.', message: /^Unsupported file content \(/, url: 'https://github.com/lutzroeder/netron/issues/550' },
  357. { name: 'Error loading model.', message: /^Unsupported Protocol Buffers content/, url: 'https://github.com/lutzroeder/netron/issues/593' },
  358. { name: 'Error loading model.', message: /^Unsupported Protocol Buffers text content/, url: 'https://github.com/lutzroeder/netron/issues/594' },
  359. { name: 'Error loading model.', message: /^Unsupported JSON content/, url: 'https://github.com/lutzroeder/netron/issues/595' },
  360. { name: 'Error loading Caffe model.', message: /^File format is not caffe\.NetParameter/, url: 'https://github.com/lutzroeder/netron/issues/563' },
  361. { name: 'Error loading Darknet model.', message: /^Invalid tensor shape/, url: 'https://github.com/lutzroeder/netron/issues/541' },
  362. { name: 'Error loading DaVinci model.', message: /^Unsupported attribute type/, url: 'https://github.com/lutzroeder/netron/issues/926' },
  363. { name: 'Error loading Keras model.', message: /^Unsupported data object header version/, url: 'https://github.com/lutzroeder/netron/issues/548' },
  364. { name: 'Error loading MNN model.', message: /^File format is not mnn\.Net/, url: 'https://github.com/lutzroeder/netron/issues/746' },
  365. { name: 'Error loading PyTorch model.', message: /^File does not contain root module or state dictionary/, url: 'https://github.com/lutzroeder/netron/issues/543' },
  366. { name: 'Error loading PyTorch model.', message: /^Module does not contain modules/, url: 'https://github.com/lutzroeder/netron/issues/544' },
  367. { name: 'Error loading PyTorch model.', message: /^Failed to resolve module/, url: 'https://github.com/lutzroeder/netron/issues/545' },
  368. { name: 'Error loading PyTorch model.', message: /^Unsupported function/, url: 'https://github.com/lutzroeder/netron/issues/546' },
  369. { name: 'Error loading PyTorch model.', message: /^Unsupported uninitialized argument/, url: 'https://github.com/lutzroeder/netron/issues/547' },
  370. { name: 'Error loading ONNX model.', message: /^File format is not onnx\.ModelProto/, url: 'https://github.com/lutzroeder/netron/issues/549' },
  371. { name: 'Error loading TensorFlow model.', message: /^File text format is not TensorFlow\.js graph-model/, url: 'https://github.com/lutzroeder/netron/issues/764' },
  372. { name: 'Error loading TensorFlow Lite model.', message: /^Offset is outside the bounds of the DataView/, url: 'https://github.com/lutzroeder/netron/issues/563' },
  373. { name: 'Error loading UFF model.', message: /^Unknown attribute/, url: 'https://github.com/lutzroeder/netron/issues/649' }
  374. ];
  375. const known = knowns.find((known) => (known.name.length === 0 || known.name === err.name) && err.message.match(known.message));
  376. const message = err.message + (known ? '\n\nPlease provide information about this issue at ' + known.url + '.' : '');
  377. name = name || err.name;
  378. this._host.error(name, message);
  379. this.show(screen !== undefined ? screen : 'welcome');
  380. if (known) {
  381. this._host.openURL(known.url);
  382. }
  383. }
  384. accept(file) {
  385. return this._modelFactoryService.accept(file);
  386. }
  387. open(context) {
  388. this._host.event('Model', 'Open', 'Size', context.stream ? context.stream.length : 0);
  389. this._sidebar.close();
  390. return this._timeout(2).then(() => {
  391. return this._modelFactoryService.open(context).then((model) => {
  392. const format = [];
  393. if (model.format) {
  394. format.push(model.format);
  395. }
  396. if (model.producer) {
  397. format.push('(' + model.producer + ')');
  398. }
  399. if (format.length > 0) {
  400. this._host.event('Model', 'Format', format.join(' '));
  401. }
  402. return this._timeout(20).then(() => {
  403. const graphs = Array.isArray(model.graphs) && model.graphs.length > 0 ? [ model.graphs[0] ] : [];
  404. return this._updateGraph(model, graphs);
  405. });
  406. });
  407. });
  408. }
  409. _updateActiveGraph(graph) {
  410. this._sidebar.close();
  411. if (this._model) {
  412. const model = this._model;
  413. this.show('welcome spinner');
  414. this._timeout(200).then(() => {
  415. return this._updateGraph(model, [ graph ]).catch((error) => {
  416. if (error) {
  417. this.error(error, 'Graph update failed.', 'welcome');
  418. }
  419. });
  420. });
  421. }
  422. }
  423. get activeGraph() {
  424. return Array.isArray(this._graphs) && this._graphs.length > 0 ? this._graphs[0] : null;
  425. }
  426. _updateGraph(model, graphs) {
  427. return this._timeout(100).then(() => {
  428. const graph = Array.isArray(graphs) && graphs.length > 0 ? graphs[0] : null;
  429. if (graph && graph != this._graphs[0]) {
  430. const nodes = graph.nodes;
  431. if (nodes.length > 2048) {
  432. if (!this._host.confirm('Large model detected.', 'This graph contains a large number of nodes and might take a long time to render. Do you want to continue?')) {
  433. this._host.event('Graph', 'Render', 'Skip', nodes.length);
  434. this.show(null);
  435. return null;
  436. }
  437. }
  438. }
  439. const update = () => {
  440. const nameButton = this._getElementById('name-button');
  441. const backButton = this._getElementById('back-button');
  442. if (this._graphs.length > 1) {
  443. const graph = this.activeGraph;
  444. nameButton.innerHTML = graph ? graph.name : '';
  445. backButton.style.opacity = 1;
  446. nameButton.style.opacity = 1;
  447. }
  448. else {
  449. backButton.style.opacity = 0;
  450. nameButton.style.opacity = 0;
  451. }
  452. };
  453. const lastModel = this._model;
  454. const lastGraphs = this._graphs;
  455. this._model = model;
  456. this._graphs = graphs;
  457. return this.renderGraph(this._model, this.activeGraph).then(() => {
  458. if (this._page !== 'default') {
  459. this.show('default');
  460. }
  461. update();
  462. return this._model;
  463. }).catch((error) => {
  464. this._model = lastModel;
  465. this._graphs = lastGraphs;
  466. return this.renderGraph(this._model, this.activeGraph).then(() => {
  467. if (this._page !== 'default') {
  468. this.show('default');
  469. }
  470. update();
  471. throw error;
  472. });
  473. });
  474. });
  475. }
  476. pushGraph(graph) {
  477. if (graph !== this.activeGraph) {
  478. this._sidebar.close();
  479. this._updateGraph(this._model, [ graph ].concat(this._graphs));
  480. }
  481. }
  482. popGraph() {
  483. if (this._graphs.length > 1) {
  484. this._sidebar.close();
  485. return this._updateGraph(this._model, this._graphs.slice(1));
  486. }
  487. return null;
  488. }
  489. renderGraph(model, graph) {
  490. try {
  491. this._graph = null;
  492. const canvas = this._getElementById('canvas');
  493. while (canvas.lastChild) {
  494. canvas.removeChild(canvas.lastChild);
  495. }
  496. if (!graph) {
  497. return Promise.resolve();
  498. }
  499. this._zoom = 1;
  500. const groups = graph.groups;
  501. const nodes = graph.nodes;
  502. this._host.event('Graph', 'Render', 'Size', nodes.length);
  503. const options = {};
  504. options.nodesep = 20;
  505. options.ranksep = 20;
  506. const rotate = graph.nodes.every((node) => node.inputs.filter((input) => input.arguments.every((argument) => !argument.initializer)).length === 0 && node.outputs.length === 0);
  507. const horizontal = rotate ? this._options.direction === 'vertical' : this._options.direction !== 'vertical';
  508. if (horizontal) {
  509. options.rankdir = "LR";
  510. }
  511. if (nodes.length > 3000) {
  512. options.ranker = 'longest-path';
  513. }
  514. const viewGraph = new view.Graph(this, model, groups, options);
  515. viewGraph.add(graph);
  516. // Workaround for Safari background drag/zoom issue:
  517. // https://stackoverflow.com/questions/40887193/d3-js-zoom-is-not-working-with-mousewheel-in-safari
  518. const background = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'rect');
  519. background.setAttribute('id', 'background');
  520. background.setAttribute('fill', 'none');
  521. background.setAttribute('pointer-events', 'all');
  522. canvas.appendChild(background);
  523. const origin = this._host.document.createElementNS('http://www.w3.org/2000/svg', 'g');
  524. origin.setAttribute('id', 'origin');
  525. canvas.appendChild(origin);
  526. viewGraph.build(this._host.document, origin);
  527. this._zoom = 1;
  528. return this._timeout(20).then(() => {
  529. viewGraph.update();
  530. const elements = Array.from(canvas.getElementsByClassName('graph-input') || []);
  531. if (elements.length === 0) {
  532. const nodeElements = Array.from(canvas.getElementsByClassName('graph-node') || []);
  533. if (nodeElements.length > 0) {
  534. elements.push(nodeElements[0]);
  535. }
  536. }
  537. const size = canvas.getBBox();
  538. const margin = 100;
  539. const width = Math.ceil(margin + size.width + margin);
  540. const height = Math.ceil(margin + size.height + margin);
  541. origin.setAttribute('transform', 'translate(' + margin.toString() + ', ' + margin.toString() + ') scale(1)');
  542. background.setAttribute('width', width);
  543. background.setAttribute('height', height);
  544. this._width = width;
  545. this._height = height;
  546. delete this._scrollLeft;
  547. delete this._scrollRight;
  548. canvas.setAttribute('viewBox', '0 0 ' + width + ' ' + height);
  549. canvas.setAttribute('width', width);
  550. canvas.setAttribute('height', height);
  551. this._zoom = 1;
  552. this._updateZoom(this._zoom);
  553. const container = this._getElementById('graph');
  554. if (elements && elements.length > 0) {
  555. // Center view based on input elements
  556. const xs = [];
  557. const ys = [];
  558. for (let i = 0; i < elements.length; i++) {
  559. const element = elements[i];
  560. const rect = element.getBoundingClientRect();
  561. xs.push(rect.left + (rect.width / 2));
  562. ys.push(rect.top + (rect.height / 2));
  563. }
  564. let x = xs[0];
  565. const y = ys[0];
  566. if (ys.every(y => y === ys[0])) {
  567. x = xs.reduce((a, b) => a + b, 0) / xs.length;
  568. }
  569. const graphRect = container.getBoundingClientRect();
  570. const left = (container.scrollLeft + x - graphRect.left) - (graphRect.width / 2);
  571. const top = (container.scrollTop + y - graphRect.top) - (graphRect.height / 2);
  572. container.scrollTo({ left: left, top: top, behavior: 'auto' });
  573. }
  574. else {
  575. const canvasRect = canvas.getBoundingClientRect();
  576. const graphRect = container.getBoundingClientRect();
  577. const left = (container.scrollLeft + (canvasRect.width / 2) - graphRect.left) - (graphRect.width / 2);
  578. const top = (container.scrollTop + (canvasRect.height / 2) - graphRect.top) - (graphRect.height / 2);
  579. container.scrollTo({ left: left, top: top, behavior: 'auto' });
  580. }
  581. this._graph = viewGraph;
  582. return;
  583. });
  584. }
  585. catch (error) {
  586. return Promise.reject(error);
  587. }
  588. }
  589. applyStyleSheet(element, name) {
  590. let rules = [];
  591. for (const styleSheet of this._host.document.styleSheets) {
  592. if (styleSheet && styleSheet.href && styleSheet.href.endsWith('/' + name)) {
  593. rules = styleSheet.cssRules;
  594. break;
  595. }
  596. }
  597. const nodes = element.getElementsByTagName('*');
  598. for (const node of nodes) {
  599. for (const rule of rules) {
  600. if (node.matches(rule.selectorText)) {
  601. for (const item of rule.style) {
  602. node.style[item] = rule.style[item];
  603. }
  604. }
  605. }
  606. }
  607. }
  608. export(file) {
  609. const lastIndex = file.lastIndexOf('.');
  610. const extension = (lastIndex != -1) ? file.substring(lastIndex + 1) : '';
  611. if (this.activeGraph && (extension === 'png' || extension === 'svg')) {
  612. const canvas = this._getElementById('canvas');
  613. const clone = canvas.cloneNode(true);
  614. this.applyStyleSheet(clone, 'view-grapher.css');
  615. clone.setAttribute('id', 'export');
  616. clone.removeAttribute('viewBox');
  617. clone.removeAttribute('width');
  618. clone.removeAttribute('height');
  619. clone.style.removeProperty('opacity');
  620. clone.style.removeProperty('display');
  621. clone.style.removeProperty('width');
  622. clone.style.removeProperty('height');
  623. const background = clone.querySelector('#background');
  624. const origin = clone.querySelector('#origin');
  625. origin.setAttribute('transform', 'translate(0,0) scale(1)');
  626. background.removeAttribute('width');
  627. background.removeAttribute('height');
  628. const parent = canvas.parentElement;
  629. parent.insertBefore(clone, canvas);
  630. const size = clone.getBBox();
  631. parent.removeChild(clone);
  632. parent.removeChild(canvas);
  633. parent.appendChild(canvas);
  634. const delta = (Math.min(size.width, size.height) / 2.0) * 0.1;
  635. const width = Math.ceil(delta + size.width + delta);
  636. const height = Math.ceil(delta + size.height + delta);
  637. origin.setAttribute('transform', 'translate(' + (delta - size.x).toString() + ', ' + (delta - size.y).toString() + ') scale(1)');
  638. clone.setAttribute('width', width);
  639. clone.setAttribute('height', height);
  640. background.setAttribute('width', width);
  641. background.setAttribute('height', height);
  642. background.setAttribute('fill', '#fff');
  643. const data = new XMLSerializer().serializeToString(clone);
  644. if (extension === 'svg') {
  645. const blob = new Blob([ data ], { type: 'image/svg' });
  646. this._host.export(file, blob);
  647. }
  648. if (extension === 'png') {
  649. const image = new Image();
  650. image.onload = () => {
  651. const max = Math.max(width, height);
  652. const scale = Math.min(24000.0 / max, 2.0);
  653. const canvas = this._host.document.createElement('canvas');
  654. canvas.width = Math.ceil(width * scale);
  655. canvas.height = Math.ceil(height * scale);
  656. const context = canvas.getContext('2d');
  657. context.scale(scale, scale);
  658. context.drawImage(image, 0, 0);
  659. canvas.toBlob((blob) => {
  660. if (blob) {
  661. this._host.export(file, blob);
  662. }
  663. else {
  664. const err = new Error();
  665. err.name = 'Error exporting image.';
  666. err.message = 'Image may be too large to render as PNG.';
  667. this._host.exception(err, false);
  668. this._host.error(err.name, err.message);
  669. }
  670. }, 'image/png');
  671. };
  672. image.src = 'data:image/svg+xml;base64,' + this._host.window.btoa(unescape(encodeURIComponent(data)));
  673. }
  674. }
  675. }
  676. showModelProperties() {
  677. if (this._model) {
  678. try {
  679. const modelSidebar = new sidebar.ModelSidebar(this._host, this._model, this.activeGraph);
  680. modelSidebar.on('update-active-graph', (sender, graph) => {
  681. this._updateActiveGraph(graph);
  682. });
  683. const content = modelSidebar.render();
  684. this._sidebar.open(content, 'Model Properties');
  685. }
  686. catch (error) {
  687. const content = " in '" + this._model.identifier + "'.";
  688. if (error && !error.message.endsWith(content) && (error.context === undefined || error.context === true)) {
  689. error.message = error.message.replace(/\.$/, '') + content;
  690. }
  691. this.error(error, 'Error showing model properties.', null);
  692. }
  693. }
  694. }
  695. showNodeProperties(node, input) {
  696. if (node) {
  697. try {
  698. const nodeSidebar = new sidebar.NodeSidebar(this._host, node);
  699. nodeSidebar.on('show-documentation', (/* sender, e */) => {
  700. this.showDocumentation(node.type);
  701. });
  702. nodeSidebar.on('show-graph', (sender, graph) => {
  703. this.pushGraph(graph);
  704. });
  705. nodeSidebar.on('export-tensor', (sender, tensor) => {
  706. const defaultPath = tensor.name ? tensor.name.split('/').join('_').split(':').join('_').split('.').join('_') : 'tensor';
  707. this._host.save('NumPy Array', 'npy', defaultPath, (file) => {
  708. try {
  709. let data_type = tensor.type.dataType;
  710. if (data_type === 'boolean') {
  711. data_type = 'bool';
  712. }
  713. const execution = new python.Execution(null);
  714. const bytes = execution.invoke('io.BytesIO', []);
  715. const dtype = execution.invoke('numpy.dtype', [ data_type ]);
  716. const array = execution.invoke('numpy.asarray', [ tensor.value, dtype ]);
  717. execution.invoke('numpy.save', [ bytes, array ]);
  718. bytes.seek(0);
  719. const blob = new Blob([ bytes.read() ], { type: 'application/octet-stream' });
  720. this._host.export(file, blob);
  721. }
  722. catch (error) {
  723. this.error(error, 'Error saving NumPy tensor.', null);
  724. }
  725. });
  726. });
  727. nodeSidebar.on('error', (sender, error) => {
  728. if (this._model) {
  729. error.message = error.message.replace(/\.$/, '') + " in '" + this._model.identifier + "'.";
  730. }
  731. this.error(error, null, null);
  732. });
  733. if (input) {
  734. nodeSidebar.toggleInput(input.name);
  735. }
  736. this._sidebar.open(nodeSidebar.render(), 'Node Properties');
  737. }
  738. catch (error) {
  739. const content = " in '" + this._model.identifier + "'.";
  740. if (error && !error.message.endsWith(content) && (error.context === undefined || error.context === true)) {
  741. error.message = error.message.replace(/\.$/, '') + content;
  742. }
  743. this.error(error, 'Error showing node properties.', null);
  744. }
  745. }
  746. }
  747. showDocumentation(type) {
  748. if (type && (type.description || type.inputs || type.outputs || type.attributes)) {
  749. if (type.nodes && type.nodes.length > 0) {
  750. this.pushGraph(type);
  751. }
  752. const documentationSidebar = new sidebar.DocumentationSidebar(this._host, type);
  753. documentationSidebar.on('navigate', (sender, e) => {
  754. this._host.openURL(e.link);
  755. });
  756. const title = type.type === 'function' ? 'Function' : 'Documentation';
  757. this._sidebar.push(documentationSidebar.render(), title);
  758. }
  759. }
  760. };
  /**
   * Renderable graph built from a model graph: one grapher node per graph input, node and
   * output, edges derived from shared arguments, and optional cluster nodes for `node.group`.
   */
  view.Graph = class extends grapher.Graph {
      /**
       * @param {object} view - Owning view (options and UI callbacks).
       * @param {object} model - Model the graph belongs to (used in error messages).
       * @param {boolean} compound - Forwarded to grapher.Graph.
       * @param {object} options - Forwarded to grapher.Graph.
       */
      constructor(view, model, compound, options) {
          super(compound, options);
          this.view = view;
          this.model = model;
          this._arguments = new Map();
          this._nodeKey = 0;
      }
      // Node keys are sequential strings, so they are unique regardless of model node names.
      createNode(node) {
          const value = new view.Node(this, node);
          value.name = (this._nodeKey++).toString();
          this.setNode(value);
          return value;
      }
      createInput(input) {
          const value = new view.Input(this, input);
          value.name = (this._nodeKey++).toString();
          this.setNode(value);
          return value;
      }
      createOutput(output) {
          const value = new view.Output(this, output);
          value.name = (this._nodeKey++).toString();
          this.setNode(value);
          return value;
      }
      // Returns the single view.Argument per argument name, creating it on first use.
      createArgument(argument) {
          const name = argument.name;
          if (!this._arguments.has(name)) {
              this._arguments.set(name, new view.Argument(this, argument));
          }
          return this._arguments.get(name);
      }
      createEdge(from, to) {
          return new view.Edge(from, to);
      }
      /**
       * Populates this graph from a model graph.
       * @param {object} graph - Model graph exposing `inputs`, `nodes`, `outputs` and optionally `groups`.
       * @throws {view.Error} When a node input or output contains a null argument.
       */
      add(graph) {
          const clusters = new Set();
          // Maps every group path and each of its prefixes to the parent path ('a/b/c' -> 'a/b', 'a' -> '').
          const clusterParentMap = new Map();
          const groups = graph.groups;
          if (groups) {
              for (const node of graph.nodes) {
                  if (node.group) {
                      const path = node.group.split('/');
                      while (path.length > 0) {
                          const name = path.join('/');
                          path.pop();
                          clusterParentMap.set(name, path.join('/'));
                      }
                  }
              }
          }
          // Cluster node ids carry a '\ngroup' suffix, presumably so they cannot collide with node names.
          const clusterId = (name) => name + '\ngroup';
          // Creates the cluster for a group path, creating and linking its ancestors first.
          const createCluster = (name) => {
              const id = clusterId(name);
              if (!clusters.has(id)) {
                  this.setNode({ name: id, rx: 5, ry: 5 });
                  clusters.add(id);
                  // Parents are keyed by the raw group path, not by the suffixed cluster id.
                  const parent = clusterParentMap.get(name);
                  if (parent) {
                      createCluster(parent);
                      this.setParent(id, clusterId(parent));
                  }
              }
          };
          const invalidArgument = () => new view.Error("Invalid null argument in '" + this.model.identifier + "'.");
          for (const input of graph.inputs) {
              const viewInput = this.createInput(input);
              for (const argument of input.arguments) {
                  this.createArgument(argument).from(viewInput);
              }
          }
          for (const node of graph.nodes) {
              const viewNode = this.createNode(node);
              for (const input of node.inputs) {
                  for (const argument of input.arguments) {
                      if (!argument) {
                          throw invalidArgument();
                      }
                      // Initializers are drawn inside the node rather than as incoming edges.
                      if (argument.name !== '' && !argument.initializer) {
                          this.createArgument(argument).to(viewNode);
                      }
                  }
              }
              // A fused chain exposes the outputs of its last chained node.
              let outputs = node.outputs;
              if (node.chain && node.chain.length > 0) {
                  const chainOutputs = node.chain[node.chain.length - 1].outputs;
                  if (chainOutputs.length > 0) {
                      outputs = chainOutputs;
                  }
              }
              for (const output of outputs) {
                  for (const argument of output.arguments) {
                      if (!argument) {
                          throw invalidArgument();
                      }
                      if (argument.name !== '') {
                          this.createArgument(argument).from(viewNode);
                      }
                  }
              }
              if (node.controlDependencies && node.controlDependencies.length > 0) {
                  for (const argument of node.controlDependencies) {
                      this.createArgument(argument).to(viewNode, true);
                  }
              }
              if (groups) {
                  let groupName = node.group;
                  if (groupName && groupName.length > 0) {
                      // Fall back to the enclosing group when the exact path is unknown.
                      if (!clusterParentMap.has(groupName)) {
                          const lastIndex = groupName.lastIndexOf('/');
                          if (lastIndex !== -1) {
                              groupName = groupName.substring(0, lastIndex);
                              if (!clusterParentMap.has(groupName)) {
                                  groupName = null;
                              }
                          }
                          else {
                              groupName = null;
                          }
                      }
                      if (groupName) {
                          createCluster(groupName);
                          this.setParent(viewNode.name, clusterId(groupName));
                      }
                  }
              }
          }
          for (const output of graph.outputs) {
              const viewOutput = this.createOutput(output);
              for (const argument of output.arguments) {
                  this.createArgument(argument).to(viewOutput);
              }
          }
      }
      // Materializes the edges of every argument before the base class builds the DOM.
      build(document, origin) {
          for (const argument of this._arguments.values()) {
              argument.build();
          }
          super.build(document, origin);
      }
  };
  /**
   * Rendered box for a model node: a header with the operator type (or name), an optional
   * list of initializers and attributes, and one extra section per chained or inner node.
   */
  view.Node = class extends grapher.Node {
      /**
       * @param {view.Graph} context - Owning graph; provides `view` (options, callbacks) and `model`.
       * @param {object} value - Model node to render.
       * @throws {view.Error} When the node type is missing or has a non-string name.
       */
      constructor(context, value) {
          super();
          this.context = context;
          this.value = value;
          view.Node.counter = view.Node.counter || 0;
          // Prefer a name-based DOM id; unnamed nodes fall back to a class-level counter.
          this.id = 'node-' + (value.name ? 'name-' + value.name : 'id-' + (view.Node.counter++).toString());
          this._add(this.value);
      }
      get class() {
          return 'graph-node';
      }
      get inputs() {
          return this.value.inputs;
      }
      get outputs() {
          return this.value.outputs;
      }
      // Appends the header and attribute/initializer list for `node`, then recurses into chain/inner.
      _add(node) {
          const header = this.header();
          const styles = [ 'node-item-type' ];
          const type = node.type;
          const category = type && type.category ? type.category : '';
          if (category) {
              styles.push('node-item-type-' + category.toLowerCase());
          }
          // #416: a missing type or a non-string type name cannot be rendered; report it as a view.Error.
          if (!type || typeof type.name !== 'string' || !type.name.split) {
              const identifier = this.context.model && this.context.model.identifier ? this.context.model.identifier : '?';
              throw new view.Error("Unsupported node type '" + JSON.stringify(type ? type.name : type) + "' in '" + identifier + "'.");
          }
          const options = this.context.view.options;
          const label = node.name || node.location;
          // With the 'names' option, show the node name and move the type into the tooltip.
          const content = options.names && label ? label : type.name.split('.').pop();
          const tooltip = options.names && label ? type.name : label;
          const title = header.add(null, styles, content, tooltip);
          title.on('click', () => this.context.view.showNodeProperties(node, null));
          if (type.nodes && type.nodes.length > 0) {
              const definition = header.add(null, styles, '\u0192', 'Show Function Definition');
              definition.on('click', () => this.context.view.pushGraph(type));
          }
          const initializers = [];
          let hiddenInitializers = false;
          if (options.initializers) {
              for (const input of node.inputs) {
                  if (input.visible && input.arguments.length === 1 && input.arguments[0].initializer != null) {
                      initializers.push(input);
                  }
                  // Initializers that cannot be listed individually are summarized by a single ellipsis row.
                  if ((!input.visible || input.arguments.length > 1) &&
                      input.arguments.some((argument) => argument.initializer != null)) {
                      hiddenInitializers = true;
                  }
              }
          }
          const attributes = node.attributes || [];
          const sortedAttributes = options.attributes ? attributes.filter((attribute) => attribute.visible) : [];
          sortedAttributes.sort((a, b) => {
              const au = a.name.toUpperCase();
              const bu = b.name.toUpperCase();
              return (au < bu) ? -1 : (au > bu) ? 1 : 0;
          });
          if (initializers.length > 0 || hiddenInitializers || sortedAttributes.length > 0) {
              const list = this.list();
              list.on('click', () => this.context.view.showNodeProperties(node));
              for (const initializer of initializers) {
                  const argument = initializer.arguments[0];
                  const argumentType = argument.type;
                  let shape = '';
                  let separator = '';
                  if (argumentType && argumentType.shape && argumentType.shape.dimensions && Array.isArray(argumentType.shape.dimensions)) {
                      shape = '\u3008' + argumentType.shape.dimensions.map((d) => d ? d : '?').join('\u00D7') + '\u3009';
                      // Scalars show their (truncated) value instead of an empty shape.
                      if (argumentType.shape.dimensions.length === 0 && argument.initializer && !argument.initializer.state) {
                          try {
                              shape = argument.initializer.toString();
                              if (shape && shape.length > 10) {
                                  shape = shape.substring(0, 10) + '\u2026';
                              }
                              separator = ' = ';
                          }
                          catch (err) {
                              let tensorType = '?';
                              try {
                                  tensorType = argument.initializer.type.toString();
                              }
                              catch (error) {
                                  // best-effort: keep '?' when the type cannot be printed either
                              }
                              const identifier = this.context.model && this.context.model.identifier ? this.context.model.identifier : '?';
                              throw new view.Error("Failed to render tensor of type '" + tensorType + "' in '" + identifier + "' (" + err.message + ").");
                          }
                      }
                  }
                  list.add(argument.name ? 'initializer-' + argument.name : '', initializer.name, shape, argumentType ? argumentType.toString() : '', separator);
              }
              if (hiddenInitializers) {
                  list.add(null, '\u3008' + '\u2026' + '\u3009', '', null, '');
              }
              for (const attribute of sortedAttributes) {
                  if (attribute.visible) {
                      let value = new sidebar.Formatter(attribute.value, attribute.type).toString();
                      if (value && value.length > 25) {
                          value = value.substring(0, 25) + '\u2026';
                      }
                      list.add(null, attribute.name, value, attribute.type, ' = ');
                  }
              }
          }
          if (Array.isArray(node.chain) && node.chain.length > 0) {
              for (const innerNode of node.chain) {
                  this._add(innerNode);
              }
          }
          if (node.inner) {
              this._add(node.inner);
          }
      }
      // NOTE(review): unused - `this._expand` and `this.canvas` are never assigned in this class
      // (the expand button is disabled), so calling this as-is would throw.
      toggle() {
          this._expand.content = '-';
          this._graph = new view.Graph(this.context.view, this.context.model, false, {});
          this._graph.add(this.value);
          this.canvas.width = 300;
          this.canvas.height = 300;
          this.layout();
          this.context.update();
      }
  };
  /**
   * Graph-input box: a single header labelled with the input name. The tooltip lists argument
   * types, one per line, and clicking opens the model properties.
   */
  view.Input = class extends grapher.Node {
      constructor(context, value) {
          super();
          this.context = context;
          this.value = value;
          if (!view.Input.counter) {
              view.Input.counter = 0;
          }
          const tooltip = value.arguments.map((argument) => argument.type || '').join('\n');
          const fullName = value.name || '';
          // Long names are shortened to their last '/' segment.
          const label = fullName.length > 16 ? fullName.split('/').pop() : fullName;
          this.header()
              .add(null, [ 'graph-item-input' ], label, tooltip)
              .on('click', () => this.context.view.showModelProperties());
          const suffix = label ? 'name-' + label : 'id-' + (view.Input.counter++).toString();
          this.id = 'input-' + suffix;
      }
      get class() {
          return 'graph-input';
      }
      // A graph input only produces its own value.
      get inputs() {
          return [];
      }
      get outputs() {
          return [ this.value ];
      }
  };
  /**
   * Graph-output box: a single header labelled with the output name. The tooltip lists argument
   * types, one per line, and clicking opens the model properties.
   */
  view.Output = class extends grapher.Node {
      constructor(context, value) {
          super();
          this.context = context;
          this.value = value;
          const tooltip = value.arguments.map((argument) => argument.type || '').join('\n');
          const fullName = value.name || '';
          // Long names are shortened to their last '/' segment.
          const label = fullName.length > 16 ? fullName.split('/').pop() : fullName;
          this.header()
              .add(null, [ 'graph-item-output' ], label, tooltip)
              .on('click', () => this.context.view.showModelProperties());
      }
      // A graph output only consumes its own value.
      get inputs() {
          return [ this.value ];
      }
      get outputs() {
          return [];
      }
  };
  /**
   * Connects the single producer of a model argument to its consumers; `build()` turns each
   * producer/consumer pair into an edge on the owning graph.
   */
  view.Argument = class {
      constructor(context, argument) {
          this.context = context;
          this._argument = argument;
      }
      // Records the producing node (a later call replaces an earlier one).
      from(node) {
          this._from = node;
      }
      // Records a consuming node; control-dependency consumers are remembered by their index.
      to(node, controlDependency) {
          if (!this._to) {
              this._to = [];
          }
          if (controlDependency) {
              if (!this._controlDependencies) {
                  this._controlDependencies = new Set();
              }
              this._controlDependencies.add(this._to.length);
          }
          this._to.push(node);
      }
      // Creates one edge per consumer, labelled with the shape (or the name if the 'names' option is set).
      build() {
          if (!this._edges) {
              this._edges = [];
          }
          if (!this._from || !this._to) {
              return;
          }
          const printable = (dim) => !dim || Number.isInteger(dim) || dim instanceof base.Int64 || (typeof dim === 'string');
          const labelFor = () => {
              let label = '';
              const type = this._argument.type;
              const dimensions = type && type.shape ? type.shape.dimensions : undefined;
              if (dimensions && dimensions.length > 0 && dimensions.every(printable)) {
                  const joined = dimensions.map((dim) => dim || '?').join('\u00D7');
                  // Long shape labels are dropped rather than truncated.
                  label = joined.length > 16 ? '' : joined;
              }
              if (this.context.view.options.names) {
                  label = this._argument.name.split('\n').shift(); // custom argument id
              }
              return label;
          };
          this._to.forEach((target, index) => {
              const label = labelFor();
              const edge = this.context.createEdge(this._from, target);
              edge.v = this._from.name;
              edge.w = target.name;
              if (label) {
                  edge.label = label;
              }
              edge.id = 'edge-' + this._argument.name;
              const isControlDependency = this._controlDependencies && this._controlDependencies.has(index);
              if (isControlDependency) {
                  edge.class = 'edge-path-control-dependency';
              }
              this.context.setEdge(edge);
              this._edges.push(edge);
          });
      }
  };
  /**
   * Graph edge with a layout hint: edges leaving a node whose every input argument is an
   * initializer (vacuously true when the node has no inputs) request a minimum length of 2.
   */
  view.Edge = class extends grapher.Edge {
      constructor(from, to) {
          super(from, to);
      }
      get minlen() {
          const onlyInitializers = this.from.inputs.every((input) => input.arguments.every((argument) => argument.initializer));
          return onlyInitializers ? 2 : 1;
      }
  };
  /**
   * Context handed to model detectors. It wraps a host or entry context and probes its stream for
   * gzip/zip/tar containers, exposing their contents through `entries()`. It also offers cached,
   * best-effort decoders (`open`) and signature probes (`tags`) over the stream.
   */
  view.ModelContext = class {
      /**
       * @param {object} context - Exposes `stream`, `entries`, `identifier`, `request`, `require`, `exception`.
       * @throws {view.ArchiveError} When archive probing fails for a context that is not a view.EntryContext.
       */
      constructor(context) {
          this._context = context;
          this._tags = new Map();
          this._content = new Map();
          let stream = context.stream;
          const entries = context.entries;
          if (!stream && entries && entries.size > 0) {
              // The context already provides a set of entries; no archive probing is needed.
              this._entries = entries;
              this._format = '';
          }
          else {
              this._entries = new Map();
              const entry = context instanceof view.EntryContext;
              const identifier = context.identifier;
              // Probing failures are reported only for top-level contexts; for entries they are ignored.
              const rethrow = (error) => {
                  if (!entry) {
                      const message = error && error.message ? error.message : error.toString();
                      throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
                  }
              };
              try {
                  const archive = gzip.Archive.open(stream);
                  if (archive) {
                      this._entries = archive.entries;
                      this._format = 'gzip';
                      // A single-file gzip is unwrapped so its payload is probed as zip/tar below.
                      if (this._entries.size === 1) {
                          stream = this._entries.values().next().value;
                      }
                  }
              }
              catch (error) {
                  rethrow(error);
              }
              try {
                  const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]);
                  for (const [ format, module ] of formats) {
                      const archive = module.Archive.open(stream);
                      if (archive) {
                          this._entries = archive.entries;
                          this._format = format;
                          break;
                      }
                  }
              }
              catch (error) {
                  rethrow(error);
              }
          }
      }
      get identifier() {
          return this._context.identifier;
      }
      get stream() {
          return this._context.stream;
      }
      request(file, encoding, base) {
          return this._context.request(file, encoding, base);
      }
      require(id) {
          return this._context.require(id);
      }
      exception(error, fatal) {
          this._context.exception(error, fatal);
      }
      // Returns the detected entries; when `format` is given and differs from the detected one, an empty Map.
      entries(format) {
          if (format !== undefined && format !== this._format) {
              return new Map();
          }
          return this._entries;
      }
      /**
       * Decodes the stream as `type` ('json', 'json.gz' or 'pkl') and caches the result (undefined on failure).
       * The stream's read position is restored afterwards, even if decoding throws.
       * @throws {view.Error} For an unsupported `type` (when decoding is not skipped).
       */
      open(type) {
          if (!this._content.has(type)) {
              // Cache a placeholder first so a failed decode is not retried.
              this._content.set(type, undefined);
              const stream = this.stream;
              if (stream) {
                  const position = stream.position;
                  try {
                      const signatures = [
                          [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ], // HDF5
                          [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ] // PyTorch
                      ];
                      // Skip when the stream has a known binary signature, when another non-flatbuffers
                      // tag probe found something, or when some other content type already decoded.
                      const skip =
                          signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) ||
                          Array.from(this._tags).some((pair) => pair[0] !== 'flatbuffers' && pair[1].size > 0) ||
                          Array.from(this._content.values()).some((obj) => obj !== undefined);
                      if (!skip) {
                          switch (type) {
                              case 'json': {
                                  try {
                                      const reader = json.TextReader.open(stream);
                                      if (reader) {
                                          const obj = reader.read();
                                          this._content.set(type, obj);
                                      }
                                  }
                                  catch (err) {
                                      // best-effort probe: not JSON
                                  }
                                  break;
                              }
                              case 'json.gz': {
                                  try {
                                      const archive = gzip.Archive.open(stream);
                                      if (archive) {
                                          const entries = archive.entries;
                                          if (entries.size === 1) {
                                              const entryStream = entries.values().next().value;
                                              const reader = json.TextReader.open(entryStream);
                                              if (reader) {
                                                  const obj = reader.read();
                                                  this._content.set(type, obj);
                                              }
                                          }
                                      }
                                  }
                                  catch (err) {
                                      // best-effort probe: not gzipped JSON
                                  }
                                  break;
                              }
                              case 'pkl': {
                                  let unpickler = null;
                                  try {
                                      if (stream.length > 2) {
                                          // A zlib header starts with 0x78 and (CMF << 8) + FLG is a multiple of 31.
                                          const zlib = (stream) => {
                                              const buffer = stream.peek(2);
                                              if (buffer[0] === 0x78) {
                                                  const check = (buffer[0] << 8) + buffer[1];
                                                  if (check % 31 === 0) {
                                                      const archive = zip.Archive.open(stream);
                                                      return archive.entries.get('');
                                                  }
                                              }
                                              return stream;
                                          };
                                          const data = zlib(stream);
                                          unpickler = python.Unpickler.open(data, () => {
                                              return new python.Execution(null, (error, fatal) => {
                                                  const message = error && error.message ? error.message : error.toString();
                                                  this.exception(new view.Error(message.replace(/\.$/, '') + " in '" + this.identifier + "'."), fatal);
                                              });
                                          });
                                      }
                                  }
                                  catch (err) {
                                      // best-effort probe: not a pickle
                                  }
                                  if (unpickler) {
                                      // Keep persistent ids as-is instead of resolving them.
                                      unpickler.persistent_load = (saved_id) => {
                                          return saved_id;
                                      };
                                      const obj = unpickler.load();
                                      this._content.set(type, obj);
                                  }
                                  break;
                              }
                              default: {
                                  throw new view.Error("Unsupported open format type '" + type + "'.");
                              }
                          }
                      }
                  }
                  finally {
                      // Restore the caller's read position (not the start of the stream), as tags() does.
                      if (stream.position !== position) {
                          stream.seek(position);
                      }
                  }
              }
          }
          return this._content.get(type);
      }
      /**
       * Probes the stream for `type` ('pbtxt', 'pb', 'pb+', 'flatbuffers' or 'xml') and caches the
       * resulting Map of tags (empty on failure). The stream's read position is restored afterwards.
       */
      tags(type) {
          if (!this._tags.has(type)) {
              let tags = new Map();
              const stream = this.stream;
              if (stream) {
                  const position = stream.position;
                  const signatures = [
                      [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ], // HDF5
                      [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ], // PyTorch
                      [ 0x50, 0x4b ], // Zip
                      [ 0x1f, 0x8b ] // Gzip
                  ];
                  const skip =
                      signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) ||
                      (Array.from(this._tags).some((pair) => pair[0] !== 'flatbuffers' && pair[1].size > 0) && type !== 'pb+') ||
                      Array.from(this._content.values()).some((obj) => obj !== undefined);
                  if (!skip) {
                      try {
                          switch (type) {
                              case 'pbtxt': {
                                  const reader = protobuf.TextReader.open(stream);
                                  tags = reader ? reader.signature() : tags;
                                  break;
                              }
                              case 'pb': {
                                  const reader = protobuf.BinaryReader.open(stream);
                                  tags = reader.signature();
                                  break;
                              }
                              case 'pb+': {
                                  const reader = protobuf.BinaryReader.open(stream);
                                  tags = reader.decode();
                                  break;
                              }
                              case 'flatbuffers': {
                                  if (stream.length >= 8) {
                                      const buffer = stream.peek(Math.min(32, stream.length));
                                      const reader = flatbuffers.BinaryReader.open(buffer);
                                      const identifier = reader.identifier;
                                      if (identifier.length > 0) {
                                          tags.set('file_identifier', identifier);
                                      }
                                  }
                                  break;
                              }
                              case 'xml': {
                                  const reader = xml.TextReader.open(stream);
                                  if (reader) {
                                      const document = reader.peek();
                                      const element = document.documentElement;
                                      const namespaceURI = element.namespaceURI;
                                      const localName = element.localName;
                                      const name = namespaceURI ? namespaceURI + ':' + localName : localName;
                                      tags.set(name, element);
                                  }
                                  break;
                              }
                              default: {
                                  // NOTE(review): this error is caught below and becomes an empty result.
                                  throw new view.Error("Unsupported tags format type '" + type + "'.");
                              }
                          }
                      }
                      catch (error) {
                          tags.clear();
                      }
                  }
                  if (stream.position !== position) {
                      stream.seek(position);
                  }
              }
              this._tags.set(type, tags);
          }
          return this._tags.get(type);
      }
      metadata(name) {
          return base.Metadata.open(this, name);
      }
  };
  /**
   * Context for a file nested inside an archive. Entry paths and the identifier are made relative
   * to `rootFolder`, and sibling entries can be requested by relative path.
   */
  view.EntryContext = class {
      /**
       * @param {object} host - Host used for requests with a `base`, `require` and `exception`.
       * @param {Iterable} entries - Iterable of [path, stream] pairs (e.g. a Map); may be falsy.
       * @param {string} rootFolder - Prefix stripped from entry paths and from `identifier`.
       * @param {string} identifier - Full path of this entry.
       * @param {object} stream - Stream of this entry.
       */
      constructor(host, entries, rootFolder, identifier, stream) {
          this._host = host;
          const relative = new Map();
          for (const pair of (entries || [])) {
              const path = pair[0];
              // Only entries under the root folder are visible, keyed by their relative path.
              if (path.startsWith(rootFolder)) {
                  relative.set(path.substring(rootFolder.length), pair[1]);
              }
          }
          this._entries = relative;
          this._identifier = identifier.substring(rootFolder.length);
          this._stream = stream;
      }
      get identifier() {
          return this._identifier;
      }
      get stream() {
          return this._stream;
      }
      // With a `base`, delegates to the host; otherwise resolves a sibling entry (decoded as text when `encoding` is set).
      request(file, encoding, base) {
          if (base !== undefined) {
              return this._host.request(file, encoding, base);
          }
          const entry = this._entries.get(file);
          if (!entry) {
              return Promise.reject(new Error('File not found.'));
          }
          if (!encoding) {
              return Promise.resolve(entry);
          }
          const text = new TextDecoder(encoding).decode(entry.peek());
          return Promise.resolve(text);
      }
      require(id) {
          return this._host.require(id);
      }
      exception(error, fatal) {
          this._host.exception(error, fatal);
      }
  };
  /**
   * Thrown by view.ModelContext when probing a top-level stream as gzip/zip/tar fails.
   * `name` holds a human-readable title rather than a class name.
   */
  view.ArchiveError = class extends Error {
      constructor(message) {
          super(message);
          const title = 'Error loading archive.';
          this.name = title;
      }
  };
  1443. view.ModelFactoryService = class {
  1444. constructor(host) {
  1445. this._host = host;
  1446. this._extensions = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]);
  1447. this._factories = [];
  1448. this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt' ], [ '.model' ]);
  1449. this.register('./onnx', [ '.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel' ]);
  1450. this.register('./mxnet', [ '.json', '.params' ], [ '.mar'] );
  1451. this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb' ], [ '.mlpackage' ]);
  1452. this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);
  1453. this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);
  1454. this.register('./torch', [ '.t7', '.net' ]);
  1455. this.register('./tflite', [ '.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json', '.txt' ]);
  1456. this.register('./circle', [ '.circle' ]);
  1457. this.register('./tf', [ '.pb', '.meta', '.pbtxt', '.prototxt', '.txt', '.pt', '.json', '.index', '.ckpt', '.graphdef', '.pbmm', /.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]$/, /^events.out.tfevents./ ], [ '.zip' ]);
  1458. this.register('./mediapipe', [ '.pbtxt' ]);
  1459. this.register('./uff', [ '.uff', '.pb', '.pbtxt', '.uff.txt', '.trt', '.engine' ]);
  1460. this.register('./tensorrt', [ '.trt', '.trtmodel', '.engine', '.model', '.txt', '.uff', '.pb', '.tmfile', '.onnx', '.pth', '.dnn', '.plan' ]);
  1461. this.register('./numpy', [ '.npz', '.npy', '.pkl', '.pickle' ]);
  1462. this.register('./lasagne', [ '.pkl', '.pickle', '.joblib', '.model', '.pkl.z', '.joblib.z' ]);
  1463. this.register('./lightgbm', [ '.txt', '.pkl', '.model' ]);
  1464. this.register('./keras', [ '.h5', '.hd5', '.hdf5', '.keras', '.json', '.cfg', '.model', '.pb', '.pth', '.weights', '.pkl', '.lite', '.tflite', '.ckpt' ], [ '.zip' ]);
  1465. this.register('./sklearn', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
  1466. this.register('./megengine', ['.tm']);
  1467. this.register('./pickle', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z', '.pdstates' ]);
  1468. this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
  1469. this.register('./paddle', [ '.pdmodel', '.pdiparams', '.pdparams', '.pdopt', '.paddle', '__model__', '.__model__', '.pbtxt', '.txt', '.tar', '.tar.gz', '.nb' ]);
  1470. this.register('./bigdl', [ '.model', '.bigdl' ]);
  1471. this.register('./darknet', [ '.cfg', '.model', '.txt', '.weights' ]);
  1472. this.register('./weka', [ '.model' ]);
  1473. this.register('./rknn', [ '.rknn', '.nb', '.onnx' ]);
  1474. this.register('./dlc', [ '.dlc', 'model', '.params' ]);
  1475. this.register('./armnn', [ '.armnn', '.json' ]);
  1476. this.register('./mnn', ['.mnn']);
  1477. this.register('./ncnn', [ '.param', '.bin', '.cfg.ncnn', '.weights.ncnn', '.ncnnmodel' ]);
  1478. this.register('./tnn', [ '.tnnproto', '.tnnmodel' ]);
  1479. this.register('./tengine', ['.tmfile']);
  1480. this.register('./mslite', [ '.ms']);
  1481. this.register('./barracuda', [ '.nn' ]);
  1482. this.register('./dnn', [ '.dnn' ]);
  1483. this.register('./xmodel', [ '.xmodel' ]);
  1484. this.register('./kmodel', [ '.kmodel' ]);
  1485. this.register('./flux', [ '.bson' ]);
  1486. this.register('./dl4j', [ '.json', '.bin' ]);
  1487. this.register('./openvino', [ '.xml', '.bin' ]);
  1488. this.register('./mlnet', [ '.zip' ]);
  1489. this.register('./acuity', [ '.json' ]);
  1490. this.register('./imgdnn', [ '.dnn', 'params', '.json' ]);
  1491. this.register('./flax', [ '.msgpack' ]);
  1492. this.register('./om', [ '.om', '.onnx', '.pb', '.engine' ]);
  1493. this.register('./nnabla', [ '.nntxt' ], [ '.nnp' ]);
  1494. this.register('./cambricon', [ '.cambricon' ]);
  1495. this.register('./message', [ '.json']);
  1496. }
  1497. register(id, factories, containers) {
  1498. for (const extension of factories) {
  1499. this._factories.push({ extension: extension, id: id });
  1500. this._extensions.add(extension);
  1501. }
  1502. for (const extension of containers || []) {
  1503. this._extensions.add(extension);
  1504. }
  1505. }
  1506. open(context) {
  1507. return this._openSignature(context).then((context) => {
  1508. const modelContext = new view.ModelContext(context);
  1509. /* eslint-disable consistent-return */
  1510. return this._openContext(modelContext).then((model) => {
  1511. if (model) {
  1512. return model;
  1513. }
  1514. const entries = modelContext.entries();
  1515. if (entries && entries.size > 0) {
  1516. return this._openEntries(entries).then((context) => {
  1517. if (context) {
  1518. return this._openContext(context);
  1519. }
  1520. this._unsupported(modelContext);
  1521. });
  1522. }
  1523. this._unsupported(modelContext);
  1524. });
  1525. /* eslint-enable consistent-return */
  1526. });
  1527. }
  1528. _unsupported(context) {
  1529. const identifier = context.identifier;
  1530. const extension = identifier.split('.').pop().toLowerCase();
  1531. const stream = context.stream;
  1532. for (const module of [ zip, tar, gzip ]) {
  1533. let archive = null;
  1534. try {
  1535. archive = module.Archive.open(stream);
  1536. }
  1537. catch (error) {
  1538. // continue regardless of error
  1539. }
  1540. if (archive) {
  1541. throw new view.Error("Archive contains no model files in '" + identifier + "'.", true);
  1542. }
  1543. }
  1544. const skip = () => {
  1545. const knownUnsupportedIdentifiers = new Set([
  1546. 'natives_blob.bin',
  1547. 'v8_context_snapshot.bin',
  1548. 'snapshot_blob.bin',
  1549. 'image_net_labels.json',
  1550. 'package.json',
  1551. 'models.json',
  1552. 'LICENSE.meta',
  1553. 'input_0.pb',
  1554. 'output_0.pb'
  1555. ]);
  1556. return knownUnsupportedIdentifiers.has(context.identifier);
  1557. };
  1558. const json = () => {
  1559. const obj = context.open('json');
  1560. if (obj) {
  1561. const formats = [
  1562. { name: 'Netron metadata', tags: [ '[].name', '[].schema' ] },
  1563. { name: 'Netron metadata', tags: [ '[].name', '[].attributes' ] },
  1564. { name: 'Netron metadata', tags: [ '[].name', '[].category' ] },
  1565. { name: 'Darkflow metadata', tags: [ 'net', 'type', 'model' ] },
  1566. { name: 'keras-yolo2 configuration', tags: [ 'model', 'train', 'valid' ] },
  1567. { name: 'Vulkan SwiftShader ICD manifest', tags: [ 'file_format_version', 'ICD' ] },
  1568. { name: 'DeepLearningExamples configuration', tags: [ 'attention_probs_dropout_prob', 'hidden_act', 'hidden_dropout_prob', 'hidden_size', ] },
  1569. { name: 'NuGet assets', tags: [ 'version', 'targets', 'packageFolders' ] },
  1570. { name: 'NuGet data', tags: [ 'format', 'restore', 'projects' ] },
  1571. { name: 'NPM package', tags: [ 'name', 'version', 'dependencies' ] },
  1572. { name: 'NetworkX adjacency_data', tags: [ 'directed', 'graph', 'nodes' ] },
  1573. { name: 'Waifu2x data', tags: [ 'name', 'arch_name', 'channels' ] },
  1574. { name: 'Waifu2x data', tags: [ '[].nInputPlane', '[].nOutputPlane', '[].weight', '[].bias' ] },
  1575. { name: 'Brain.js data', tags: [ 'type', 'sizes', 'layers' ] },
  1576. { name: 'Custom Vision metadata', tags: [ 'CustomVision.Metadata.Version' ] },
  1577. { name: 'W&B metadata', tags: [ 'program', 'host', 'executable' ] }
  1578. ];
  1579. const match = (obj, tag) => {
  1580. if (tag.startsWith('[].')) {
  1581. tag = tag.substring(3);
  1582. return (Array.isArray(obj) && obj.some((item) => Object.prototype.hasOwnProperty.call(item, tag)));
  1583. }
  1584. return Object.prototype.hasOwnProperty.call(obj, tag);
  1585. };
  1586. for (const format of formats) {
  1587. if (format.tags.every((tag) => match(obj, tag))) {
  1588. throw new view.Error('Invalid file content. File contains ' + format.name + '.', true);
  1589. }
  1590. }
  1591. const content = JSON.stringify(obj).substring(0, 100).replace(/\s/, '').substr(0, 48) + '...';
  1592. throw new view.Error("Unsupported JSON content '" + (content.length > 64 ? content.substring(0, 100) + '...' : content) + "' for extension '." + extension + "' in '" + identifier + "'.", !skip());
  1593. }
  1594. };
  1595. const pbtxt = () => {
  1596. const formats = [
  1597. { name: 'ImageNet LabelMap data', tags: [ 'entry', 'entry.target_class' ] },
  1598. { name: 'StringIntLabelMapProto data', tags: [ 'item', 'item.id', 'item.name' ] },
  1599. { name: 'caffe.LabelMap data', tags: [ 'item', 'item.name', 'item.label' ] },
  1600. { name: 'Triton Inference Server configuration', tags: [ 'name', 'platform', 'input', 'output' ] },
  1601. { name: 'TensorFlow OpList data', tags: [ 'op', 'op.name', 'op.input_arg' ] },
  1602. { name: 'vitis.ai.proto.DpuModelParamList data', tags: [ 'model', 'model.name', 'model.kernel' ] },
  1603. { name: 'object_detection.protos.DetectionModel data', tags: [ 'model', 'model.ssd' ] },
  1604. { name: 'object_detection.protos.DetectionModel data', tags: [ 'model', 'model.faster_rcnn' ] },
  1605. { name: 'tensorflow.CheckpointState data', tags: [ 'model_checkpoint_path', 'all_model_checkpoint_paths' ] },
  1606. { name: 'apollo.perception.camera.traffic_light.detection.DetectionParam data', tags: [ 'min_crop_size', 'crop_method' ] },
  1607. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'caffe_ssd' ] }, // https://github.com/TexasInstruments/edgeai-mmdetection/blob/master/mmdet/utils/proto/mmdet_meta_arch.proto
  1608. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'tf_od_api_ssd' ] },
  1609. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'tidl_ssd' ] },
  1610. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'tidl_faster_rcnn' ] },
  1611. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'tidl_yolo' ] },
  1612. { name: 'tidl_meta_arch.TIDLMetaArch data', tags: [ 'tidl_retinanet' ] },
  1613. { name: 'domi.InsertNewOps data', tags: [ 'aipp_op' ] } // https://github.com/Ascend/parser/blob/development/parser/proto/insert_op.proto
  1614. ];
  1615. const tags = context.tags('pbtxt');
  1616. if (tags.size > 0) {
  1617. for (const format of formats) {
  1618. if (format.tags.every((tag) => tags.has(tag))) {
  1619. throw new view.Error('Invalid file content. File contains ' + format.name + '.', true);
  1620. }
  1621. }
  1622. const entries = [];
  1623. entries.push(...Array.from(tags).filter((pair) => pair[0].toString().indexOf('.') === -1));
  1624. entries.push(...Array.from(tags).filter((pair) => pair[0].toString().indexOf('.') !== -1));
  1625. const content = entries.map((pair) => pair[1] === true ? pair[0] : pair[0] + ':' + JSON.stringify(pair[1])).join(',');
  1626. throw new view.Error("Unsupported Protocol Buffers text content '" + (content.length > 64 ? content.substring(0, 100) + '...' : content) + "' for extension '." + extension + "' in '" + identifier + "'.", !skip());
  1627. }
  1628. };
  1629. const pb = () => {
  1630. const tags = context.tags('pb+');
  1631. if (Object.keys(tags).length > 0) {
  1632. const formats = [
  1633. { name: 'sentencepiece.ModelProto data', tags: [[1,[[1,2],[2,5],[3,0]]],[2,[[1,2],[2,2],[3,0],[4,0],[5,2],[6,0],[7,2],[10,5],[16,0],[40,0],[41,0],[42,0],[43,0]]],[3,[]],[4,[]],[5,[]]] },
  1634. { name: 'mediapipe.BoxDetectorIndex data', tags: [[1,[[1,[[1,[[1,5],[2,5],[3,5],[4,5],[6,0],[7,5],[8,5],[10,5],[11,0],[12,0]]],[2,5],[3,[]]]],[2,false],[3,false],[4,false],[5,false]]],[2,false],[3,false]] },
  1635. { name: 'third_party.tensorflow.python.keras.protobuf.SavedMetadata data', tags: [[1,[[1,[[1,0],[2,0]]],[2,0],[3,2],[4,2],[5,2]]]] },
  1636. { name: 'pblczero.Net data', tags: [[1,5],[2,2],[3,[[1,0],[2,0],[3,0]],[10,[[1,[]],[2,[]],[3,[]],[4,[]],[5,[]],[6,[]]]],[11,[]]]] } // https://github.com/LeelaChessZero/lczero-common/blob/master/proto/net.proto
  1637. ];
  1638. const match = (tags, schema) => {
  1639. for (const pair of schema) {
  1640. const key = pair[0];
  1641. const inner = pair[1];
  1642. const value = tags[key];
  1643. if (value === undefined) {
  1644. continue;
  1645. }
  1646. if (inner === false) {
  1647. return false;
  1648. }
  1649. if (Array.isArray(inner)) {
  1650. if (typeof value !== 'object' || !match(value, inner)) {
  1651. return false;
  1652. }
  1653. }
  1654. else if (inner !== value) {
  1655. if (inner === 2 && !Array.isArray(value) && Object(value) === (value) && Object.keys(value).length === 0) {
  1656. return true;
  1657. }
  1658. return false;
  1659. }
  1660. }
  1661. return true;
  1662. };
  1663. const tags = context.tags('pb+');
  1664. for (const format of formats) {
  1665. if (match(tags, format.tags)) {
  1666. throw new view.Error('Invalid file content. File contains ' + format.name + '.', true);
  1667. }
  1668. }
  1669. const format = (tags) => {
  1670. const content = Object.entries(tags).map((pair) => {
  1671. const key = pair[0];
  1672. const value = pair[1];
  1673. return key.toString() + ':' + (Object(value) === value ? '{' + format(value) + '}' : value.toString());
  1674. });
  1675. return content.join(',');
  1676. };
  1677. const content = format(tags);
  1678. throw new view.Error("Unsupported Protocol Buffers content '" + (content.length > 64 ? content.substring(0, 100) + '...' : content) + "' for extension '." + extension + "' in '" + identifier + "'.", !skip());
  1679. }
  1680. };
  1681. const flatbuffers = () => {
  1682. const tags = context.tags('flatbuffers');
  1683. if (tags.has('file_identifier')) {
  1684. const file_identifier = tags.get('file_identifier');
  1685. const formats = [
  1686. { name: 'onnxruntime.experimental.fbs.InferenceSession data', identifier: 'ORTM' },
  1687. { name: 'tflite.Model data', identifier: 'TFL3' },
  1688. { name: 'torch.jit.mobile.serialization.Module data', identifier: 'PTMF' }, // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/serialization/mobile_bytecode.fbs
  1689. { name: 'FlatBuffers ENNC data', identifier: 'ENNC' },
  1690. ];
  1691. for (const format of formats) {
  1692. if (file_identifier === format.identifier) {
  1693. throw new view.Error('Invalid file content. File contains ' + format.name + '.', true);
  1694. }
  1695. }
  1696. }
  1697. };
  1698. const xml = () => {
  1699. const tags = context.tags('xml');
  1700. if (tags.size > 0) {
  1701. const formats = [
  1702. { name: 'OpenCV storage data', tags: [ 'opencv_storage' ] },
  1703. { name: 'XHTML markup', tags: [ 'http://www.w3.org/1999/xhtml:html' ]}
  1704. ];
  1705. for (const format of formats) {
  1706. if (format.tags.some((tag) => tags.has(tag))) {
  1707. throw new view.Error('Invalid file content. File contains ' + format.name + '.', true);
  1708. }
  1709. }
  1710. throw new view.Error("Unsupported XML content '" + tags.keys().next().value + "' in '" + identifier + "'.", !skip());
  1711. }
  1712. };
  1713. const unknown = () => {
  1714. if (stream) {
  1715. stream.seek(0);
  1716. const buffer = stream.peek(Math.min(16, stream.length));
  1717. const bytes = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
  1718. const content = stream.length > 268435456 ? '(' + bytes + ') [' + stream.length.toString() + ']': '(' + bytes + ')';
  1719. throw new view.Error("Unsupported file content " + content + " for extension '." + extension + "' in '" + identifier + "'.", !skip());
  1720. }
  1721. throw new view.Error("Unsupported file directory in '" + identifier + "'.", !skip());
  1722. };
  1723. json();
  1724. pbtxt();
  1725. pb();
  1726. flatbuffers();
  1727. xml();
  1728. unknown();
  1729. }
  1730. _openContext(context) {
  1731. const modules = this._filter(context).filter((module) => module && module.length > 0);
  1732. const errors = [];
  1733. let success = false;
  1734. const nextModule = () => {
  1735. if (modules.length > 0) {
  1736. const id = modules.shift();
  1737. return this._host.require(id).then((module) => {
  1738. const updateErrorContext = (error, context) => {
  1739. const content = " in '" + context.identifier + "'.";
  1740. if (error && typeof error.message === 'string' && !error.message.endsWith(content) && (error.context === undefined || error.context === true)) {
  1741. error.message = error.message.replace(/\.$/, '') + content;
  1742. }
  1743. };
  1744. if (!module.ModelFactory) {
  1745. throw new view.Error("Failed to load module '" + id + "'.");
  1746. }
  1747. const modelFactory = new module.ModelFactory();
  1748. let match = undefined;
  1749. try {
  1750. match = modelFactory.match(context);
  1751. if (!match) {
  1752. return nextModule();
  1753. }
  1754. }
  1755. catch (error) {
  1756. updateErrorContext(error, context);
  1757. return Promise.reject(error);
  1758. }
  1759. success = true;
  1760. return modelFactory.open(context, match).then((model) => {
  1761. if (!model.identifier) {
  1762. model.identifier = context.identifier;
  1763. }
  1764. return model;
  1765. }).catch((error) => {
  1766. if (context.stream && context.stream.position !== 0) {
  1767. context.stream.seek(0);
  1768. }
  1769. updateErrorContext(error, context);
  1770. errors.push(error);
  1771. return nextModule();
  1772. });
  1773. });
  1774. }
  1775. if (success) {
  1776. if (errors.length === 1) {
  1777. const error = errors[0];
  1778. return Promise.reject(error);
  1779. }
  1780. return Promise.reject(new view.Error(errors.map((err) => err.message).join('\n')));
  1781. }
  1782. return Promise.resolve(null);
  1783. };
  1784. return nextModule();
  1785. }
  1786. _openEntries(entries) {
  1787. try {
  1788. const rootFolder = (files) => {
  1789. const map = files.map((file) => file.split('/').slice(0, -1));
  1790. const at = index => list => list[index];
  1791. const rotate = list => list.length === 0 ? [] : list[0].map((item, index) => list.map(at(index)));
  1792. const equals = list => list.every((item) => item === list[0]);
  1793. const folder = rotate(map).filter(equals).map(at(0)).join('/');
  1794. return folder.length === 0 ? folder : folder + '/';
  1795. };
  1796. const filter = (queue) => {
  1797. let matches = [];
  1798. const nextEntry = () => {
  1799. if (queue.length > 0) {
  1800. const entry = queue.shift();
  1801. const context = new view.ModelContext(new view.EntryContext(this._host, null, folder, entry.name, entry.stream));
  1802. let modules = this._filter(context);
  1803. const nextModule = () => {
  1804. if (modules.length > 0) {
  1805. const id = modules.shift();
  1806. return this._host.require(id).then((module) => {
  1807. if (!module.ModelFactory) {
  1808. throw new view.ArchiveError("Failed to load module '" + id + "'.", null);
  1809. }
  1810. const factory = new module.ModelFactory();
  1811. if (factory.match(context)) {
  1812. matches.push(entry);
  1813. modules = [];
  1814. }
  1815. return nextModule();
  1816. });
  1817. }
  1818. return nextEntry();
  1819. };
  1820. return nextModule();
  1821. }
  1822. if (matches.length === 0) {
  1823. return Promise.resolve(null);
  1824. }
  1825. // MXNet
  1826. if (matches.length === 2 &&
  1827. matches.some((e) => e.name.toLowerCase().endsWith('.params')) &&
  1828. matches.some((e) => e.name.toLowerCase().endsWith('-symbol.json'))) {
  1829. matches = matches.filter((e) => e.name.toLowerCase().endsWith('.params'));
  1830. }
  1831. // TensorFlow.js
  1832. if (matches.length > 0 &&
  1833. matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
  1834. matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
  1835. matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
  1836. }
  1837. // ncnn
  1838. if (matches.length > 0 &&
  1839. matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
  1840. matches.some((e) => e.name.toLowerCase().endsWith('.param'))) {
  1841. matches = matches.filter((e) => e.name.toLowerCase().endsWith('.param'));
  1842. }
  1843. // ncnn
  1844. if (matches.length > 0 &&
  1845. matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
  1846. matches.some((e) => e.name.toLowerCase().endsWith('.param.bin'))) {
  1847. matches = matches.filter((e) => e.name.toLowerCase().endsWith('.param.bin'));
  1848. }
  1849. // Paddle
  1850. if (matches.length > 0 &&
  1851. matches.some((e) => e.name.toLowerCase().endsWith('.pdmodel')) &&
  1852. (matches.some((e) => e.name.toLowerCase().endsWith('.pdparams')) ||
  1853. matches.some((e) => e.name.toLowerCase().endsWith('.pdopt')) ||
  1854. matches.some((e) => e.name.toLowerCase().endsWith('.pdiparams')))) {
  1855. matches = matches.filter((e) => e.name.toLowerCase().endsWith('.pdmodel'));
  1856. }
  1857. // Paddle Lite
  1858. if (matches.length > 0 &&
  1859. matches.some((e) => e.name.toLowerCase().split('/').pop() === '__model__.nb') &&
  1860. matches.some((e) => e.name.toLowerCase().split('/').pop() === 'param.nb')) {
  1861. matches = matches.filter((e) => e.name.toLowerCase().split('/').pop() == '__model__.nb');
  1862. }
  1863. // TensorFlow Bundle
  1864. if (matches.length > 1 &&
  1865. matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {
  1866. matches = matches.filter((e) => !e.name.toLowerCase().endsWith('.data-00000-of-00001'));
  1867. }
  1868. // TensorFlow SavedModel
  1869. if (matches.length === 2 &&
  1870. matches.some((e) => e.name.toLowerCase().split('/').pop() === 'keras_metadata.pb')) {
  1871. matches = matches.filter((e) => e.name.toLowerCase().split('/').pop() !== 'keras_metadata.pb');
  1872. }
  1873. if (matches.length > 1) {
  1874. return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
  1875. }
  1876. const match = matches.shift();
  1877. return Promise.resolve(new view.ModelContext(new view.EntryContext(this._host, entries, folder, match.name, match.stream)));
  1878. };
  1879. return nextEntry();
  1880. };
  1881. const list = Array.from(entries).map((entry) => {
  1882. return { name: entry[0], stream: entry[1] };
  1883. });
  1884. const files = list.filter((entry) => {
  1885. if (entry.name.endsWith('/')) {
  1886. return false;
  1887. }
  1888. if (entry.name.split('/').pop().startsWith('.')) {
  1889. return false;
  1890. }
  1891. if (!entry.name.startsWith('./') && entry.name.startsWith('.')) {
  1892. return false;
  1893. }
  1894. return true;
  1895. });
  1896. const folder = rootFolder(files.map((entry) => entry.name));
  1897. const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') < 0);
  1898. return filter(queue).then((context) => {
  1899. if (context) {
  1900. return Promise.resolve(context);
  1901. }
  1902. const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') >= 0);
  1903. return filter(queue);
  1904. });
  1905. }
  1906. catch (error) {
  1907. return Promise.reject(new view.ArchiveError(error.message));
  1908. }
  1909. }
  1910. accept(identifier) {
  1911. const extension = identifier.indexOf('.') === -1 ? '' : identifier.split('.').pop().toLowerCase();
  1912. identifier = identifier.toLowerCase().split('/').pop();
  1913. for (const extension of this._extensions) {
  1914. if ((typeof extension === 'string' && identifier.endsWith(extension)) || (extension instanceof RegExp && extension.exec(identifier))) {
  1915. this._host.event('File', 'Accept', extension, 1);
  1916. return true;
  1917. }
  1918. }
  1919. this._host.event('File', 'Reject', extension, 1);
  1920. return false;
  1921. }
  1922. _filter(context) {
  1923. const identifier = context.identifier.toLowerCase().split('/').pop();
  1924. const list = this._factories.filter((entry) =>
  1925. (typeof entry.extension === 'string' && identifier.endsWith(entry.extension)) ||
  1926. (entry.extension instanceof RegExp && entry.extension.exec(identifier)));
  1927. return Array.from(new Set(list.map((entry) => entry.id)));
  1928. }
  1929. _openSignature(context) {
  1930. const stream = context.stream;
  1931. if (stream) {
  1932. let empty = true;
  1933. let position = 0;
  1934. while (empty && position < stream.length) {
  1935. const buffer = stream.read(Math.min(4096, stream.length - position));
  1936. position += buffer.length;
  1937. if (!buffer.every((value) => value === 0x00)) {
  1938. empty = false;
  1939. break;
  1940. }
  1941. }
  1942. stream.seek(0);
  1943. if (empty) {
  1944. return Promise.reject(new view.Error('File has no content.', true));
  1945. }
  1946. /* eslint-disable no-control-regex */
  1947. const entries = [
  1948. { name: 'ELF executable', value: /^\x7FELF/ },
  1949. { name: 'PNG image', value: /^\x89PNG/ },
  1950. { name: 'Git LFS header', value: /^version https:\/\/git-lfs.github.com/ },
  1951. { name: 'Git LFS header', value: /^\s*oid sha256:/ },
  1952. { name: 'HTML markup', value: /^\s*<html>/ },
  1953. { name: 'HTML markup', value: /^\s*<!doctype\s*html>/ },
  1954. { name: 'HTML markup', value: /^\s*<!DOCTYPE\s*html>/ },
  1955. { name: 'HTML markup', value: /^\s*<!DOCTYPE\s*HTML>/ },
  1956. { name: 'HTML markup', value: /^\s*<!DOCTYPE\s*HTML\s+(PUBLIC|SYSTEM)?/ },
  1957. { name: 'Unity metadata', value: /^fileFormatVersion:/ },
  1958. { name: 'Python source code', value: /^\s*import[ ]+(os|sys|types|torch|argparse|onnx|numpy|tensorflow)(,|;|\s)/ },
  1959. { name: 'Python source code', value: /^\s*import[ ]+([a-z])+[ ]+as[ ]+/ },
  1960. { name: 'Python source code', value: /^\s*from[ ]+(torch)[ ]+import[ ]+/ },
  1961. { name: 'Python source code', value: /^\s*from[ ]+(keras)[ ]+import[ ]+/ },
  1962. { name: 'Bash script', value: /^#!\/usr\/bin\/env\s/ },
  1963. { name: 'Bash script', value: /^#!\/bin\/bash\s/ },
  1964. { name: 'TSD header', value: /^%TSD-Header-###%/ },
  1965. { name: 'AppleDouble data', value: /^\x00\x05\x16\x07/ },
  1966. { name: 'TensorFlow Hub module', value: /^\x08\x03$/, identifier: 'tfhub_module.pb' },
  1967. { name: 'ViSQOL model', value: /^svm_type\snu_svr/ },
  1968. { name: 'SenseTime model', value: /^STEF/ }
  1969. ];
  1970. /* eslint-enable no-control-regex */
  1971. const buffer = stream.peek(Math.min(4096, stream.length));
  1972. const content = String.fromCharCode.apply(null, buffer);
  1973. for (const entry of entries) {
  1974. if (content.match(entry.value) && (!entry.identifier || entry.identifier === context.identifier)) {
  1975. return Promise.reject(new view.Error('Invalid file content. File contains ' + entry.name + '.', true));
  1976. }
  1977. }
  1978. }
  1979. return Promise.resolve(context);
  1980. }
  1981. };
  1982. view.Error = class extends Error {
  1983. constructor(message, telemetry) {
  1984. super(message);
  1985. this.name = 'Error loading model.';
  1986. this.telemetry = telemetry;
  1987. this.stack = undefined;
  1988. }
  1989. };
  1990. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  1991. module.exports.View = view.View;
  1992. module.exports.ModelFactoryService = view.ModelFactoryService;
  1993. }