onnx.js 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640
  1. /* jshint esversion: 6 */
  2. var onnx = onnx || {};
  3. var protobuf = protobuf || require('./protobuf');
  4. var flatbuffers = flatbuffers || require('./flatbuffers');
  5. onnx.ModelFactory = class {
  6. match(context) {
  7. const identifier = context.identifier;
  8. const extension = identifier.split('.').pop().toLowerCase();
  9. if (identifier.endsWith('saved_model.pb') || identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
  10. return undefined;
  11. }
  12. if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
  13. identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
  14. return undefined;
  15. }
  16. let tags = context.tags('pb');
  17. if (tags.size > 0) {
  18. if (tags.size === 1 && tags.get(1) === 2) {
  19. const tags = context.tags('pb+');
  20. const match = (tags, schema) => {
  21. for (const pair of schema) {
  22. const key = pair[0];
  23. const inner = pair[1];
  24. if (tags[key] === undefined) {
  25. continue;
  26. }
  27. else if (inner === false) {
  28. return false;
  29. }
  30. if (Array.isArray(inner)) {
  31. const value = tags[key];
  32. if (typeof value !== 'object' || !match(value, inner)) {
  33. return false;
  34. }
  35. }
  36. else if (inner !== tags[key]) {
  37. return false;
  38. }
  39. }
  40. return true;
  41. };
  42. // mediapipe.BoxDetectorIndex
  43. if (match(tags, [[1,[[1,[[1,[[1,5],[2,5],[3,5],[4,5],[6,0],[7,5],[8,5],[10,5],[11,0],[12,0]]],[2,5],[3,[]]]],[2,false],[3,false],[4,false],[5,false]]],[2,false],[3,false]] )) {
  44. return undefined;
  45. }
  46. // third_party.tensorflow.python.keras.protobuf.SavedMetadata
  47. if (match(tags, [[1,[[1,[[1,0],[2,0]]],[2,0],[3,2],[4,2],[5,2]]]])) {
  48. return undefined;
  49. }
  50. }
  51. if (Array.from(tags.keys()).every((tag) => tag <= 100) &&
  52. Array.from(tags.values()).every((type) => type < 5)) {
  53. // TensorProto
  54. if (tags.get(1) === 0 && tags.get(2) === 0 && tags.get(9) === 2) {
  55. const schema = [[1,0],[2,0],[4,2],[5,2],[7,2],[8,2],[9,2]];
  56. if (schema.every((pair) => !tags.has(pair[0]) || tags.get(pair[0]) === pair[1])) {
  57. return 'onnx.pb.TensorProto';
  58. }
  59. }
  60. // GraphProto
  61. if (tags.get(1) === 2) {
  62. const schema = [[1,2],[2,2],[3,2],[4,2],[5,2],[6,0],[7,0],[8,2],[9,2],[10,2],[11,2],[12,2],[13,2],[14,2]];
  63. if (schema.every((pair) => !tags.has(pair[0]) || tags.get(pair[0]) === pair[1])) {
  64. const decode = (buffer, value) => {
  65. const reader = protobuf.BinaryReader.open(buffer);
  66. const length = reader.length;
  67. while (reader.position < length) {
  68. const tag = reader.uint32();
  69. const number = tag >>> 3;
  70. const type = tag & 7;
  71. if (value === number) {
  72. return type === 2 ? reader.bytes() : null;
  73. }
  74. else {
  75. reader.skipType(type);
  76. }
  77. }
  78. return null;
  79. };
  80. const stream = context.stream;
  81. const buffer = stream.peek();
  82. const nodeBuffer = decode(buffer, 1);
  83. if (nodeBuffer) {
  84. const nameBuffer = decode(nodeBuffer, 4);
  85. if (nameBuffer && nameBuffer.every((c) => c > 0x20 && c < 0x7f)) {
  86. return 'onnx.pb.GraphProto';
  87. }
  88. }
  89. }
  90. }
  91. // ModelProto
  92. if (tags.get(7) === 2) {
  93. const schema = [[1,0],[2,2],[3,2],[4,2][5,0],[6,2],[7,2],[8,2],[14,2],[20,2]];
  94. if (schema.every((pair) => !tags.has(pair[0]) || tags.get(pair[0]) === pair[1])) {
  95. return 'onnx.pb.ModelProto';
  96. }
  97. }
  98. }
  99. }
  100. const stream = context.stream;
  101. if (stream.length > 5) {
  102. const buffer = stream.peek(Math.min(stream.length, 32));
  103. if (buffer[0] === 0x08 && buffer[1] < 0x0A && buffer[2] === 0x12) {
  104. const producers = [
  105. 'backend-test', 'BrainwaveCompiler',
  106. 'CNTK',
  107. 'keras2onnx', 'Kneron', 'kneron_formatter', 'kneron_kl530_test_case',
  108. 'darknet to ONNX example',
  109. 'htshinichi',
  110. 'MATLAB Deep Learning Toolbox Converter for ONNX Model Format', 'ML.NET', 'MVTec Software',
  111. 'onnx-caffe2', 'onnx-example', 'onnx.quantize', 'onnx.utils.extract_model', 'OnnxMLTools', 'onnx_test', 'onnxruntime-tools', 'onnxruntime.transformers',
  112. 'PaddlePaddle', 'pytorch',
  113. 'sclblonnx', 'skl2onnx',
  114. 'Tencent YouTu', 'tf2onnx', 'tflite2onnx',
  115. 'WinMLTools'
  116. ];
  117. if (producers.some((producer) => Array.from(producer).every((ch, index) => index + 4 < buffer.length && ch.charCodeAt(0) === buffer[index + 4]))) {
  118. return 'onnx.pb.ModelProto';
  119. }
  120. }
  121. }
  122. tags = context.tags('pbtxt');
  123. if (tags.has('ir_version')) {
  124. return 'onnx.pbtxt.ModelProto';
  125. }
  126. if (tags.has('graph') && extension !== 'model') {
  127. return 'onnx.pbtxt.ModelProto';
  128. }
  129. if (context.tags('flatbuffers').get('file_identifier') === 'ORTM') {
  130. return 'onnx.flatbuffers';
  131. }
  132. return undefined;
  133. }
  134. open(context, match) {
  135. const open = (model, format) => {
  136. return onnx.Metadata.open(context).then((metadata) => {
  137. return new onnx.Model(metadata, model, format);
  138. });
  139. };
  140. switch (match) {
  141. case 'onnx.pbtxt.ModelProto':
  142. return context.require('./onnx-proto').then(() => {
  143. try {
  144. onnx.proto = protobuf.get('onnx').onnx;
  145. const stream = context.stream;
  146. const reader = protobuf.TextReader.open(stream);
  147. const model = onnx.proto.ModelProto.decodeText(reader);
  148. const format = 'ONNX' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
  149. return open(model, format);
  150. }
  151. catch (error) {
  152. const message = error && error.message ? error.message : error.toString();
  153. throw new onnx.Error('File text format is not onnx.ModelProto (' + message.replace(/\.$/, '') + ').');
  154. }
  155. });
  156. case 'onnx.pb.TensorProto':
  157. return context.require('./onnx-proto').then(() => {
  158. // TensorProto
  159. // input_0.pb, output_0.pb
  160. try {
  161. onnx.proto = protobuf.get('onnx').onnx;
  162. const stream = context.stream;
  163. const reader = protobuf.BinaryReader.open(stream);
  164. const tensor = onnx.proto.TensorProto.decode(reader);
  165. tensor.name = tensor.name || context.identifier;
  166. const model = new onnx.proto.ModelProto();
  167. model.graph = new onnx.proto.GraphProto();
  168. model.graph.initializer = [ tensor ];
  169. model.graph.value_info = [ new onnx.proto.ValueInfoProto() ];
  170. model.graph.value_info[0].name = tensor.name;
  171. model.graph.node = [ new onnx.proto.NodeProto() ];
  172. model.graph.node[0].op_type = 'Constant';
  173. model.graph.node[0].attribute = [ new onnx.proto.AttributeProto() ];
  174. model.graph.node[0].attribute[0].name = 'value';
  175. model.graph.node[0].attribute[0].t = tensor;
  176. const format = 'ONNX Tensor';
  177. return open(model, format);
  178. }
  179. catch (error) {
  180. const message = error && error.message ? error.message : error.toString();
  181. throw new onnx.Error('File format is not onnx.TensorProto (' + message.replace(/\.$/, '') + ').');
  182. }
  183. });
  184. case 'onnx.pb.GraphProto':
  185. return context.require('./onnx-proto').then(() => {
  186. // GraphProto
  187. try {
  188. onnx.proto = protobuf.get('onnx').onnx;
  189. const stream = context.stream;
  190. const reader = protobuf.BinaryReader.open(stream);
  191. const model = new onnx.proto.ModelProto();
  192. model.graph = onnx.proto.GraphProto.decode(reader);
  193. const format = 'ONNX';
  194. return open(model, format);
  195. }
  196. catch (error) {
  197. const message = error && error.message ? error.message : error.toString();
  198. throw new onnx.Error('File format is not onnx.GraphProto (' + message.replace(/\.$/, '') + ').');
  199. }
  200. });
  201. case 'onnx.pb.ModelProto':
  202. return context.require('./onnx-proto').then(() => {
  203. // ModelProto
  204. try {
  205. onnx.proto = protobuf.get('onnx').onnx;
  206. const stream = context.stream;
  207. const reader = protobuf.BinaryReader.open(stream);
  208. const model = onnx.proto.ModelProto.decode(reader);
  209. const format = 'ONNX' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
  210. return open(model, format);
  211. }
  212. catch (error) {
  213. const message = error && error.message ? error.message : error.toString();
  214. throw new onnx.Error('File format is not onnx.ModelProto (' + message.replace(/\.$/, '') + ').');
  215. }
  216. });
  217. case 'onnx.flatbuffers': {
  218. return context.require('./ort-schema').then((/* schema */) => {
  219. try {
  220. onnx.schema = flatbuffers.get('ort').onnxruntime.experimental.fbs;
  221. const stream = context.stream;
  222. const reader = flatbuffers.BinaryReader.open(stream);
  223. const session = onnx.schema.InferenceSession.create(reader);
  224. const model = session.model;
  225. const graph = model.graph;
  226. graph.node = graph.nodes;
  227. graph.doc_string = model.graph_doc_string;
  228. graph.value_info = graph.node_args;
  229. graph.input = graph.inputs.map((input) => {
  230. return { name: input };
  231. });
  232. graph.output = graph.outputs.map((output) => {
  233. return { name: output };
  234. });
  235. graph.initializer = graph.initializers.map((tensor) => {
  236. tensor.data_location = onnx.DataLocation.DEFAULT;
  237. return tensor;
  238. });
  239. graph.sparse_initializer = graph.sparse_initializers.map((tensor) => {
  240. tensor.values.data_location = onnx.DataLocation.DEFAULT;
  241. tensor.indices.data_location = onnx.DataLocation.DEFAULT;
  242. return tensor;
  243. });
  244. delete graph.nodes;
  245. delete graph.node_args;
  246. delete graph.inputs;
  247. delete graph.outputs;
  248. delete graph.initializers;
  249. delete graph.sparse_initializers;
  250. delete model.graph_doc_string;
  251. for (const node of graph.node) {
  252. node.input = node.inputs;
  253. node.output = node.outputs;
  254. node.attribute = node.attributes;
  255. delete node.inputs;
  256. delete node.outputs;
  257. delete node.attributes;
  258. }
  259. const format = 'ONNX Runtime' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
  260. return open(model, format);
  261. }
  262. catch (error) {
  263. const message = error && error.message ? error.message : error.toString();
  264. throw new onnx.Error('File format is not ort.Model (' + message.replace(/\.$/, '') + ').');
  265. }
  266. });
  267. }
  268. default: {
  269. throw new onnx.Error("Unknown ONNX format '" + match + "'.");
  270. }
  271. }
  272. }
  273. };
  274. onnx.Model = class {
  275. constructor(metadata, model, format) {
  276. this._graphs = [];
  277. this._format = format;
  278. this._producerName = model.producer_name;
  279. this._producerVersion = model.producer_version;
  280. this._domain = model.domain;
  281. this._modelVersion = model.model_version;
  282. this._description = model.doc_string;
  283. this._metadata = [];
  284. this._imports = null;
  285. const imports = new Map();
  286. if (model.opset_import && model.opset_import.length > 0) {
  287. for (const opset_import of model.opset_import) {
  288. const domain = opset_import.domain || 'ai.onnx';
  289. const version = opset_import.version ? opset_import.version.toNumber() : 0;
  290. if (!imports.has(domain) || imports.get(domain) > version) {
  291. imports.set(domain, version);
  292. }
  293. }
  294. this._imports = Array.from(imports).map((pair) => pair[0] + ' v' + pair[1].toString());
  295. }
  296. if (imports.size == 0) {
  297. imports.set('ai.onnx', 1);
  298. imports.set('ai.onnx.ml', 1);
  299. }
  300. let imageFormat = '';
  301. if (model.metadata_props) {
  302. const imageMetadata = {};
  303. for (const metadata_prop of model.metadata_props) {
  304. switch (metadata_prop.key) {
  305. case 'author':
  306. this._author = metadata_prop.value;
  307. break;
  308. case 'company':
  309. this._company = metadata_prop.value;
  310. break;
  311. case 'converted_from':
  312. this._converted_from = metadata_prop.value;
  313. break;
  314. case 'license':
  315. this._license = metadata_prop.value;
  316. break;
  317. case 'license_url':
  318. this._licenseUrl = metadata_prop.value;
  319. break;
  320. case 'Image.BitmapPixelFormat':
  321. case 'Image.ColorSpaceGamma':
  322. case 'Image.NominalPixelRange':
  323. imageMetadata[metadata_prop.key] = metadata_prop.value;
  324. break;
  325. default:
  326. this._metadata.push({ name: metadata_prop.key, value: metadata_prop.value});
  327. break;
  328. }
  329. }
  330. imageFormat = [ imageMetadata['Image.BitmapPixelFormat'], imageMetadata['Image.ColorSpaceGamma'], imageMetadata['Image.NominalPixelRange'] ].filter((item) => item);
  331. }
  332. this._graphs = [];
  333. if (model && model.graph) {
  334. let key = 1000;
  335. const context = {};
  336. context.metadata = new onnx.GraphMetadata(metadata, imports);
  337. context.imageFormat = imageFormat;
  338. for (const func of model.functions || []) {
  339. context.metadata.add(new onnx.Function(context, func));
  340. }
  341. context.graphs = new Map();
  342. context.graph = function(graph) {
  343. graph.key = graph.key || (key++).toString();
  344. if (!this.graphs.has(graph.key)) {
  345. this.graphs.set(graph.key, new onnx.Graph(this, graph));
  346. }
  347. return this.graphs.get(graph.key);
  348. };
  349. const graphs = [ model.graph ];
  350. while (graphs.length > 0) {
  351. const graph = graphs.shift();
  352. this._graphs.push(context.graph(graph));
  353. for (const node of graph.node || []) {
  354. for (const attribute of node.attribute || []) {
  355. if (attribute.g) {
  356. graphs.push(attribute.g);
  357. }
  358. else if (attribute.graphs && attribute.graphs.length > 0) {
  359. graphs.push(...attribute.graphs);
  360. }
  361. }
  362. }
  363. }
  364. }
  365. }
  366. get format() {
  367. return this._format;
  368. }
  369. get imports() {
  370. return this._imports;
  371. }
  372. get producer() {
  373. const producer = [];
  374. if (this._producerName) {
  375. producer.push(this._producerName);
  376. }
  377. if (this._producerVersion && this._producerVersion.length > 0) {
  378. producer.push(this._producerVersion);
  379. }
  380. if (producer.length > 0) {
  381. return producer.join(' ');
  382. }
  383. return null;
  384. }
  385. get domain() {
  386. return this._domain || null;
  387. }
  388. get description() {
  389. return this._description || null;
  390. }
  391. get author() {
  392. return this._author || null;
  393. }
  394. get company() {
  395. return this._company || null;
  396. }
  397. get source() {
  398. return this._converted_from || null;
  399. }
  400. get license() {
  401. const license = [];
  402. if (this._license && this._license.length > 0) {
  403. license.push(this._license);
  404. }
  405. if (this._licenseUrl && this._licenseUrl.length > 0) {
  406. license.push('<a href=\'' + this._licenseUrl + '\'>' + this._licenseUrl + '</a>');
  407. }
  408. if (license.length > 0) {
  409. return license;
  410. }
  411. return null;
  412. }
  413. get metadata() {
  414. return this._metadata;
  415. }
  416. get graphs() {
  417. return this._graphs;
  418. }
  419. };
  420. onnx.Graph = class {
  421. constructor(context, graph) {
  422. this._node = '';
  423. this._description = '';
  424. this._nodes = [];
  425. this._inputs = [];
  426. this._outputs = [];
  427. this._name = graph.name || null;
  428. this._description = graph.doc_string || '';
  429. const tensors = onnx.Utility.createTensors(graph.node);
  430. for (const initializer of graph.initializer) {
  431. const tensor = tensors.map(initializer.name);
  432. tensor.initializer = new onnx.Tensor(initializer, 'Initializer');
  433. }
  434. for (const sparse_initializer of graph.sparse_initializer) {
  435. const tensor = tensors.map(sparse_initializer.values.name);
  436. tensor.initializer = new onnx.Tensor(sparse_initializer, 'Sparse Initializer');
  437. }
  438. for (const tensor_annotation of graph.quantization_annotation || []) {
  439. const tensor = tensors.map(tensor_annotation.tensor_name);
  440. const annotation = {};
  441. for (const pair of tensor_annotation.quant_parameter_tensor_names) {
  442. annotation[pair.key] = pair.value;
  443. }
  444. tensor.annotation = annotation;
  445. }
  446. for (const valueInfo of graph.value_info) {
  447. const tensor = tensors.map(valueInfo.name);
  448. tensor.type = onnx.Utility.formatType(valueInfo.type, context.imageFormat);
  449. tensor.description = valueInfo.doc_string;
  450. }
  451. graph.input = graph.input.map((valueInfo) => {
  452. const tensor = tensors.map(valueInfo.name);
  453. tensor.type = onnx.Utility.formatType(valueInfo.type, context.imageFormat);
  454. tensor.description = valueInfo.doc_string;
  455. return tensor;
  456. });
  457. graph.output = graph.output.map((valueInfo) => {
  458. const tensor = tensors.map(valueInfo.name);
  459. tensor.type = onnx.Utility.formatType(valueInfo.type, context.imageFormat);
  460. tensor.description = valueInfo.doc_string;
  461. return tensor;
  462. });
  463. new onnx.Inference(graph.node, graph.output);
  464. const args = new Map();
  465. args.map = function(name) {
  466. if (!this.has(name)) {
  467. const tensor = tensors.map(name);
  468. const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
  469. this.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description));
  470. }
  471. return this.get(name);
  472. };
  473. this._nodes = onnx.Utility.createNodes(context, graph.node, graph.input, graph.output, tensors, args);
  474. for (const input of graph.input) {
  475. const argument = args.map(input.name);
  476. if (!argument.initializer) {
  477. this._inputs.push(new onnx.Parameter(input.name, [ argument ]));
  478. }
  479. }
  480. for (const output of graph.output) {
  481. const argument = args.map(output.name);
  482. if (!argument.initializer) {
  483. this._outputs.push(new onnx.Parameter(output.name, [ argument ]));
  484. }
  485. }
  486. }
  487. get name() {
  488. return this._name;
  489. }
  490. get description() {
  491. return this._description;
  492. }
  493. get groups() {
  494. return false;
  495. }
  496. get inputs() {
  497. return this._inputs;
  498. }
  499. get outputs() {
  500. return this._outputs;
  501. }
  502. get nodes() {
  503. return this._nodes;
  504. }
  505. toString() {
  506. return 'graph(' + this.name + ')';
  507. }
  508. };
  509. onnx.Parameter = class {
  510. constructor(name, args) {
  511. this._name = name;
  512. this._arguments = args;
  513. }
  514. get name() {
  515. return this._name;
  516. }
  517. get visible() {
  518. return true;
  519. }
  520. get arguments() {
  521. return this._arguments;
  522. }
  523. };
  524. onnx.Argument = class {
  525. constructor(name, type, initializer, annotation, description) {
  526. if (typeof name !== 'string') {
  527. throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  528. }
  529. this._name = name;
  530. this._type = type || null;
  531. this._initializer = initializer || null;
  532. this._annotation = annotation;
  533. this._description = description || '';
  534. }
  535. get name() {
  536. return this._name;
  537. }
  538. get type() {
  539. return this._type;
  540. }
  541. get description() {
  542. return this._description;
  543. }
  544. get quantization() {
  545. if (this._annotation) {
  546. return Object.keys(this._annotation).map((key) => key + ': ' + this._annotation[key]).join(', ');
  547. }
  548. return null;
  549. }
  550. get initializer() {
  551. return this._initializer;
  552. }
  553. };
  554. onnx.Node = class {
  555. constructor(context, op_type, domain, name, description, attributes, inputs, outputs) {
  556. this._type = context.metadata.type(op_type, domain) || { name: op_type, module: domain };
  557. if (this.type.module !== domain && !(this._type instanceof onnx.Function)) {
  558. this._type = Object.assign({}, this.type);
  559. this._type.name = op_type;
  560. this._type.module = domain;
  561. }
  562. this._name = name || '';
  563. this._description = description || '';
  564. this._inputs = inputs;
  565. this._outputs = outputs;
  566. this._attributes = (attributes || []).map((attribute) => new onnx.Attribute(context, op_type, domain, attribute));
  567. }
  568. get type() {
  569. return this._type;
  570. }
  571. get name() {
  572. return this._name;
  573. }
  574. get description() {
  575. return this._description;
  576. }
  577. get group() {
  578. return null;
  579. }
  580. get attributes() {
  581. return this._attributes;
  582. }
  583. get inputs() {
  584. return this._inputs;
  585. }
  586. get outputs() {
  587. return this._outputs;
  588. }
  589. };
  590. onnx.Attribute = class {
  591. constructor(context, op_type, domain, attribute) {
  592. this._name = attribute.name;
  593. this._description = attribute.doc_string || '';
  594. this._type = null;
  595. this._value = null;
  596. switch (attribute.type) {
  597. case onnx.AttributeType.FLOAT:
  598. this._value = attribute.f;
  599. this._type = 'float32';
  600. break;
  601. case onnx.AttributeType.INT:
  602. this._value = attribute.i;
  603. this._type = 'int64';
  604. break;
  605. case onnx.AttributeType.STRING:
  606. switch (op_type) {
  607. case 'Int8GivenTensorFill':
  608. this._value = Array.from(attribute.s);
  609. break;
  610. default:
  611. this._value = onnx.Utility.decodeText(attribute.s);
  612. break;
  613. }
  614. this._type = 'string';
  615. break;
  616. case onnx.AttributeType.TENSOR:
  617. this._value = new onnx.Tensor(attribute.t);
  618. this._type = 'tensor';
  619. break;
  620. case onnx.AttributeType.GRAPH:
  621. this._value = context.graph(attribute.g);
  622. this._type = 'graph';
  623. break;
  624. case onnx.AttributeType.FLOATS:
  625. this._value = attribute.floats;
  626. this._type = 'float32[]';
  627. break;
  628. case onnx.AttributeType.INTS:
  629. this._value = attribute.ints;
  630. this._type = 'int64[]';
  631. break;
  632. case onnx.AttributeType.STRINGS:
  633. this._value = attribute.strings.map((s) => onnx.Utility.decodeText(s));
  634. this._type = 'string[]';
  635. break;
  636. case onnx.AttributeType.TENSORS:
  637. this._value = attribute.tensors.map((tensor) => new onnx.Tensor(tensor));
  638. this._type = 'tensor[]';
  639. break;
  640. case onnx.AttributeType.GRAPHS:
  641. this._value = attribute.graphs.map((graph) => context.graph(graph));
  642. this._type = 'graph[]';
  643. break;
  644. case onnx.AttributeType.SPARSE_TENSOR:
  645. this._value = new onnx.Tensor(attribute.sparse_tensor);
  646. this._type = 'tensor';
  647. break;
  648. case onnx.AttributeType.SPARSE_TENSORS:
  649. this._value = attribute.sparse_tensors.map((tensor) => new onnx.Tensor(tensor));
  650. this._type = 'tensor[]';
  651. break;
  652. case onnx.AttributeType.TYPE_PROTO:
  653. this._value = onnx.Utility.formatType(attribute.tp, context.imageFormat);
  654. this._type = 'type';
  655. break;
  656. case onnx.AttributeType.TYPE_PROTOS:
  657. this._value = attribute.type_protos.map((type) => onnx.Utility.formatType(type, context.imageFormat));
  658. this._type = 'type[]';
  659. break;
  660. default:
  661. throw new onnx.Error("Unknown attribute type '" + attribute.type + "'.");
  662. }
  663. const metadata = context.metadata.attribute(op_type, domain, attribute.name);
  664. if (metadata && Object.prototype.hasOwnProperty.call(metadata, 'default') && this._value == metadata.default) {
  665. this._visible = false;
  666. }
  667. }
  668. get name() {
  669. return this._name;
  670. }
  671. get type() {
  672. return this._type;
  673. }
  674. get value() {
  675. return this._value;
  676. }
  677. get description() {
  678. return this._description;
  679. }
  680. get visible() {
  681. return this._visible == false ? false : true;
  682. }
  683. };
  684. onnx.Tensor = class {
  685. constructor(tensor, kind) {
  686. this._kind = kind || null;
  687. const data = (tensor) => {
  688. let data = undefined;
  689. if (tensor.data_location === onnx.DataLocation.DEFAULT) {
  690. switch (tensor.data_type) {
  691. case onnx.DataType.FLOAT16:
  692. if (tensor.int32_data && tensor.int32_data.length > 0) {
  693. const buffer = new Uint8Array(tensor.int32_data.length << 1);
  694. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  695. const array = tensor.int32_data;
  696. for (let i = 0; i < array.length; i++) {
  697. view.setUint16(i << 1, array[i], true);
  698. }
  699. data = {
  700. type: tensor.data_type,
  701. buffer: buffer
  702. };
  703. }
  704. break;
  705. case onnx.DataType.FLOAT:
  706. data = new Float32Array(tensor.float_data);
  707. break;
  708. case onnx.DataType.DOUBLE:
  709. data = new Float64Array(tensor.double_data);
  710. break;
  711. case onnx.DataType.BOOL:
  712. data = new Array(tensor.int32_data.size);
  713. for (let i = 0; i < data.length; i++) {
  714. data[i] = data[i] === 0 ? false : true;
  715. }
  716. break;
  717. case onnx.DataType.INT8:
  718. data = new Int8Array(tensor.int32_data);
  719. break;
  720. case onnx.DataType.UINT8:
  721. data = new Uint8Array(tensor.int32_data);
  722. break;
  723. case onnx.DataType.INT16:
  724. data = new Int32Array(tensor.int32_data);
  725. break;
  726. case onnx.DataType.UINT16:
  727. data = new Int32Array(tensor.int32_data);
  728. break;
  729. case onnx.DataType.INT32:
  730. data = new Int32Array(tensor.int32_data);
  731. break;
  732. case onnx.DataType.UINT32:
  733. case onnx.DataType.UINT64:
  734. data = tensor.uint64_data;
  735. break;
  736. case onnx.DataType.INT64:
  737. data = tensor.int64_data;
  738. break;
  739. }
  740. if (data && (Array.isArray(data) || ArrayBuffer.isView(data)) && data.length === 0) {
  741. data = undefined;
  742. }
  743. if (!data && tensor.raw_data && tensor.raw_data.length > 0) {
  744. data = {
  745. type: tensor.data_type,
  746. buffer: tensor.raw_data
  747. };
  748. }
  749. }
  750. return data;
  751. };
  752. const location = (tensor) => {
  753. return onnx.Utility.formatLocation(tensor.data_location);
  754. };
  755. if ((onnx.proto && tensor instanceof onnx.proto.SparseTensorProto) ||
  756. (onnx.schema && tensor instanceof onnx.schema.SparseTensor)) {
  757. this._name = tensor.values.name || '';
  758. this._type = new onnx.TensorType(tensor.values.data_type, new onnx.TensorShape(tensor.dims.map((dim) => dim)), null);
  759. this._location = Array.from(new Set([ location(tensor.values), location(tensor.indices) ])).join(':');
  760. this._values = data(tensor.values);
  761. this._indices = data(tensor.indices);
  762. }
  763. else {
  764. this._name = tensor.name || '';
  765. this._type = new onnx.TensorType(tensor.data_type, new onnx.TensorShape(tensor.dims.map((dim) => dim)), null);
  766. this._location = location(tensor);
  767. this._values = data(tensor);
  768. }
  769. }
  770. get name() {
  771. return this._name;
  772. }
  773. get kind() {
  774. return this._kind;
  775. }
  776. get type() {
  777. return this._type;
  778. }
  779. get state() {
  780. return this._context().state || null;
  781. }
  782. get value() {
  783. const context = this._context();
  784. if (context.state) {
  785. return null;
  786. }
  787. context.limit = Number.MAX_SAFE_INTEGER;
  788. return this._decode(context, 0);
  789. }
  790. toString() {
  791. const context = this._context();
  792. if (context.state) {
  793. return '';
  794. }
  795. context.limit = 10000;
  796. const value = this._decode(context, 0);
  797. return onnx.Tensor._stringify(value, '', ' ');
  798. }
    /**
     * Builds the decoding context used by `state`, `value` and `toString()`.
     *
     * When the tensor cannot be decoded, the returned object only carries a
     * human-readable `state` string (sparse flag set, a location other than
     * 'default', or no decodable values). Otherwise `state` is null and the
     * object also carries `values`, `indices`, `index` (read cursor, starts
     * at 0), `dataType`, `shape` and a lazy `data()` accessor.
     *
     * Side effect: `this._values` and `this._indices` are replaced by their
     * decoded form, so a raw buffer is decoded at most once per tensor.
     *
     * @returns {object} the context described above.
     */
    _context() {
        const context = {};
        context.state = null;
        // NOTE(review): `_sparse` is not assigned in the visible part of the constructor
        // (sparse tensors populate `_indices` instead) — confirm whether this branch is reachable.
        if (this._sparse) {
            context.state = 'Sparse data not implemented.';
            return context;
        }
        // Only inline ('default') storage is decoded; any other location is reported as a state.
        if (this._location !== 'default') {
            context.state = "Data '" + this._location + "' location not implemented.";
            return context;
        }
        // Decodes a `{ type, buffer }` record into a typed or plain array. Falsy values and
        // values that are already arrays / typed arrays are returned unchanged. Multi-byte
        // elements are read little-endian (the `true` argument). Element types without a case
        // here (e.g. STRING, COMPLEX64) decode to undefined.
        // NOTE(review): getFloat16/getInt64/getUint64 are not standard DataView methods; they are
        // presumably installed by a DataView extension elsewhere in the project — confirm.
        const decode = (data) => {
            if (!data || Array.isArray(data) || ArrayBuffer.isView(data)) {
                return data;
            }
            const buffer = data.buffer;
            const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
            const type = data.type;
            data = undefined;
            switch (type) {
                case onnx.DataType.BOOL:
                    // One byte per element; any non-zero byte is true.
                    data = new Array(buffer.length);
                    for (let i = 0; i < buffer.length; i++) {
                        data[i] = view.getUint8(i) === 0 ? false : true;
                    }
                    break;
                case onnx.DataType.FLOAT16:
                    // Half-precision values are widened into a Float32Array.
                    data = new Float32Array(buffer.length >> 1);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getFloat16(i << 1, true);
                    }
                    break;
                case onnx.DataType.FLOAT:
                    data = new Float32Array(buffer.length >> 2);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getFloat32(i << 2, true);
                    }
                    break;
                case onnx.DataType.DOUBLE:
                    data = new Float64Array(buffer.length >> 3);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getFloat64(i << 3, true);
                    }
                    break;
                case onnx.DataType.INT8:
                    data = new Int8Array(buffer.length);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getInt8(i, true);
                    }
                    break;
                case onnx.DataType.UINT8:
                    data = new Uint8Array(buffer.length);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getUint8(i, true);
                    }
                    break;
                case onnx.DataType.INT16:
                    data = new Int16Array(buffer.length >> 1);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getInt16(i << 1, true);
                    }
                    break;
                case onnx.DataType.UINT16:
                    data = new Uint16Array(buffer.length >> 1);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getUint16(i << 1, true);
                    }
                    break;
                case onnx.DataType.INT32:
                    data = new Int32Array(buffer.length >> 2);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getInt32(i << 2, true);
                    }
                    break;
                case onnx.DataType.UINT32:
                    data = new Uint32Array(buffer.length >> 2);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getUint32(i << 2, true);
                    }
                    break;
                case onnx.DataType.INT64:
                    // 64-bit integers are kept in a plain Array of whatever getInt64 returns.
                    data = new Array(buffer.length >> 3);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getInt64(i << 3, true);
                    }
                    break;
                case onnx.DataType.UINT64:
                    data = new Array(buffer.length >> 3);
                    for (let i = 0; i < data.length; i++) {
                        data[i] = view.getUint64(i << 3, true);
                    }
                    break;
            }
            return data;
        };
        this._values = decode(this._values);
        if (!this._values) {
            context.state = 'Tensor data is empty.';
            return context;
        }
        this._indices = decode(this._indices);
        context.values = this._values;
        context.indices = this._indices;
        context.index = 0;
        context.dataType = this.type.dataType;
        context.shape = this.type.shape.dimensions;
        // Lazily materializes the flat element array. `this` is the context object because
        // callers invoke it as `context.data()`; the result is cached in `context._data`.
        context.data = function() {
            if (!this._data) {
                // One index per value: scatter the values into a dense array sized by the
                // product of the shape. Otherwise the values are already dense and used as-is.
                if (this.indices && this.values && this.indices.length === this.values.length) {
                    const size = context.shape.reduce((a, b) => a * b, 1);
                    const indices = this.indices;
                    const values = this.values;
                    const array = new values.constructor(size);
                    switch (this.dataType) {
                        case 'boolean':
                            array.fill(false);
                            break;
                        case 'int64':
                        case 'uint64':
                            // No fill: positions without an index keep the constructor default
                            // (undefined when `values` is a plain Array).
                            break;
                    }
                    if (indices.length > 0) {
                        // Indices exposing `low`/`high` are 64-bit integer objects: use `low`
                        // directly when `high` is 0, otherwise convert via toNumber().
                        if (Object.prototype.hasOwnProperty.call(indices[0], 'low')) {
                            for (let i = 0; i < indices.length; i++) {
                                const index = indices[i];
                                array[index.high === 0 ? index.low : index.toNumber()] = values[i];
                            }
                        }
                        else {
                            for (let i = 0; i < indices.length; i++) {
                                array[indices[i]] = values[i];
                            }
                        }
                    }
                    this._data = array;
                }
                else {
                    this._data = this.values;
                }
            }
            return this._data;
        };
        return context;
    }
  943. _decode(context, dimension) {
  944. const shape = context.shape.length !== 0 ? context.shape : [ 1 ];
  945. const results = [];
  946. const size = shape[dimension];
  947. const data = context.data();
  948. if (dimension == shape.length - 1) {
  949. for (let i = 0; i < size; i++) {
  950. if (context.index > context.limit) {
  951. results.push('...');
  952. return results;
  953. }
  954. results.push(data[context.index++]);
  955. }
  956. }
  957. else {
  958. for (let j = 0; j < size; j++) {
  959. if (context.index > context.limit) {
  960. results.push('...');
  961. return results;
  962. }
  963. results.push(this._decode(context, dimension + 1));
  964. }
  965. }
  966. if (context.shape.length == 0) {
  967. return results[0];
  968. }
  969. return results;
  970. }
  971. static _stringify(value, indentation, indent) {
  972. if (Array.isArray(value)) {
  973. const result = [];
  974. result.push(indentation + '[');
  975. const items = value.map((item) => onnx.Tensor._stringify(item, indentation + indent, indent));
  976. if (items.length > 0) {
  977. result.push(items.join(',\n'));
  978. }
  979. result.push(indentation + ']');
  980. return result.join('\n');
  981. }
  982. if (typeof value == 'string') {
  983. return indentation + value;
  984. }
  985. if (value == Infinity) {
  986. return indentation + 'Infinity';
  987. }
  988. if (value == -Infinity) {
  989. return indentation + '-Infinity';
  990. }
  991. if (isNaN(value)) {
  992. return indentation + 'NaN';
  993. }
  994. return indentation + value.toString();
  995. }
  996. };
  997. onnx.TensorType = class {
  998. constructor(dataType, shape, denotation) {
  999. this._dataType = onnx.Utility.formatElementType(dataType);
  1000. this._shape = shape;
  1001. this._denotation = denotation || null;
  1002. }
  1003. get dataType() {
  1004. return this._dataType;
  1005. }
  1006. get shape() {
  1007. return this._shape;
  1008. }
  1009. get denotation() {
  1010. return this._denotation;
  1011. }
  1012. toString() {
  1013. return this.dataType + this._shape.toString();
  1014. }
  1015. };
  1016. onnx.TensorShape = class {
  1017. constructor(dimensions) {
  1018. this._dimensions = dimensions;
  1019. }
  1020. get dimensions() {
  1021. return this._dimensions;
  1022. }
  1023. toString() {
  1024. if (!this._dimensions || this._dimensions.length == 0) {
  1025. return '';
  1026. }
  1027. return '[' + this._dimensions.join(',') + ']';
  1028. }
  1029. };
  1030. onnx.SequenceType = class {
  1031. constructor(elementType, denotation) {
  1032. this._elementType = elementType;
  1033. this._denotation = denotation;
  1034. }
  1035. get elementType() {
  1036. return this._elementType;
  1037. }
  1038. get dennotation() {
  1039. return this._dennotation;
  1040. }
  1041. toString() {
  1042. return 'sequence<' + this._elementType.toString() + '>';
  1043. }
  1044. };
  1045. onnx.MapType = class {
  1046. constructor(keyType, valueType, denotation) {
  1047. this._keyType = onnx.Utility.formatElementType(keyType);
  1048. this._valueType = valueType;
  1049. this._denotation = denotation;
  1050. }
  1051. get keyType() {
  1052. return this._keyType;
  1053. }
  1054. get valueType() {
  1055. return this._valueType;
  1056. }
  1057. get denotation() {
  1058. return this._denotation;
  1059. }
  1060. toString() {
  1061. return 'map<' + this._keyType + ',' + this._valueType.toString() + '>';
  1062. }
  1063. };
  1064. onnx.OpaqueType = class {
  1065. constructor(domain, name) {
  1066. this._domain = domain;
  1067. this._name = name;
  1068. }
  1069. toString() {
  1070. const name = (this._domain ? (this._domain + '.') : '') + this._name;
  1071. return 'opaque<' + name + '>';
  1072. }
  1073. };
  1074. onnx.Function = class {
  1075. constructor(context, func) {
  1076. this._name = func.name;
  1077. this._domain = func.domain;
  1078. this._description = func.doc_string;
  1079. this._inputs = [];
  1080. this._outputs = [];
  1081. this._attributes = func.attribute.map((attribtue) => { return { name: attribtue }; });
  1082. const tensors = onnx.Utility.createTensors(func.node);
  1083. func.input = func.input.map((input) => tensors.map(input));
  1084. func.output = func.output.map((output) => tensors.map(output));
  1085. const args = new Map();
  1086. args.map = function(name) {
  1087. if (!this.has(name)) {
  1088. const tensor = tensors.map(name);
  1089. const type = tensor.initializer ? tensor.initializer.type : tensor.type || null;
  1090. this.set(name, new onnx.Argument(name, type, tensor.initializer, tensor.annotation, tensor.description));
  1091. }
  1092. return this.get(name);
  1093. };
  1094. this._nodes = onnx.Utility.createNodes(context, func.node, func.input, func.output, tensors, args);
  1095. for (const input of func.input) {
  1096. const argument = args.map(input.name);
  1097. if (!argument.initializer) {
  1098. this._inputs.push(new onnx.Parameter(input.name, [ argument ]));
  1099. }
  1100. }
  1101. for (const output of func.output) {
  1102. const argument = args.map(output.name);
  1103. if (!argument.initializer) {
  1104. this._outputs.push(new onnx.Parameter(output.name, [ argument ]));
  1105. }
  1106. }
  1107. }
  1108. get type() {
  1109. return 'function';
  1110. }
  1111. get name() {
  1112. return this._name;
  1113. }
  1114. get module() {
  1115. return this._domain;
  1116. }
  1117. get description() {
  1118. return this._description;
  1119. }
  1120. get inputs() {
  1121. return this._inputs;
  1122. }
  1123. get outputs() {
  1124. return this._outputs;
  1125. }
  1126. get attributes() {
  1127. return this._attributes;
  1128. }
  1129. get nodes() {
  1130. return this._nodes;
  1131. }
  1132. };
  1133. onnx.GraphMetadata = class {
  1134. constructor(metadata, imports) {
  1135. this._metadata = metadata;
  1136. this._imports = imports;
  1137. this._cache = new Map();
  1138. this._attributeCache = new Map();
  1139. this._functions = new Map();
  1140. }
  1141. add(func) {
  1142. if (!this._functions.has(func.module)) {
  1143. this._functions.set(func.module, new Map());
  1144. }
  1145. const map = this._functions.get(func.module);
  1146. if (map.has(func.name)) {
  1147. throw new onnx.Error("Duplicate function identifier '" + func.module + '.' + func.name + "'.");
  1148. }
  1149. map.set(func.name, func);
  1150. }
  1151. type(name, domain) {
  1152. domain = domain || 'ai.onnx';
  1153. const key = domain + ':' + name;
  1154. if (!this._cache.has(key)) {
  1155. let value = this._metadata.type(name, domain, this._imports);
  1156. if (!value) {
  1157. if (this._functions.has(domain)) {
  1158. const map = this._functions.get(domain);
  1159. if (map.has(name)) {
  1160. value = map.get(name);
  1161. }
  1162. }
  1163. }
  1164. this._cache.set(key, value);
  1165. }
  1166. return this._cache.get(key);
  1167. }
  1168. attribute(type, domain, name) {
  1169. const key = domain + ':' + type + ':' + name;
  1170. if (!this._attributeCache.has(key)) {
  1171. const schema = this.type(type, domain);
  1172. if (schema && schema.attributes && schema.attributes.length > 0) {
  1173. for (const attribute of schema.attributes) {
  1174. this._attributeCache.set(type + ':' + attribute.name, attribute);
  1175. }
  1176. }
  1177. if (!this._attributeCache.has(key)) {
  1178. this._attributeCache.set(key, null);
  1179. }
  1180. }
  1181. return this._attributeCache.get(key);
  1182. }
  1183. };
  1184. onnx.Metadata = class {
  1185. static open(context) {
  1186. if (onnx.Metadata._metadata) {
  1187. return Promise.resolve(onnx.Metadata._metadata);
  1188. }
  1189. return context.request('onnx-metadata.json', 'utf-8', null).then((data) => {
  1190. onnx.Metadata._metadata = new onnx.Metadata(data);
  1191. return onnx.Metadata._metadata;
  1192. }).catch(() => {
  1193. onnx.Metadata._metadata = new onnx.Metadata(null);
  1194. return onnx.Metadata._metadata;
  1195. });
  1196. }
  1197. constructor(data) {
  1198. this._map = new Map();
  1199. if (data) {
  1200. const metadata = JSON.parse(data);
  1201. for (const item of metadata) {
  1202. if (!this._map.has(item.module)) {
  1203. this._map.set(item.module, new Map());
  1204. }
  1205. const map = this._map.get(item.module);
  1206. if (!map.has(item.name)) {
  1207. map.set(item.name, []);
  1208. }
  1209. map.get(item.name).push(item);
  1210. }
  1211. }
  1212. }
  1213. type(name, domain, imports) {
  1214. domain = domain || 'ai.onnx';
  1215. let current = null;
  1216. if (this._map.has(domain)) {
  1217. const map = this._map.get(domain);
  1218. if (map.has(name)) {
  1219. for (const metadata of map.get(name)) {
  1220. const matchVersion = current ? current.version : -1;
  1221. const importVersion = imports.get(metadata.module) || 0;
  1222. if (importVersion >= metadata.version && matchVersion < metadata.version) {
  1223. current = metadata;
  1224. }
  1225. }
  1226. }
  1227. }
  1228. return current;
  1229. }
  1230. };
  1231. onnx.Inference = class {
  1232. constructor(nodes, outputs) {
  1233. this._outputs = new Map();
  1234. for (const node of nodes) {
  1235. for (const output of node.output) {
  1236. this._outputs.set(output.name, node);
  1237. }
  1238. }
  1239. for (const output of outputs) {
  1240. this._infer(output.name);
  1241. }
  1242. }
  1243. _infer(output) {
  1244. if (this._outputs.has(output)) {
  1245. let hasInputShapes = true;
  1246. const node = this._outputs.get(output);
  1247. for (const input of node.input) {
  1248. if (!input.type) {
  1249. this._infer(input);
  1250. if (!input.type) {
  1251. hasInputShapes = false;
  1252. break;
  1253. }
  1254. }
  1255. }
  1256. if (hasInputShapes) {
  1257. // continue
  1258. }
  1259. }
  1260. }
  1261. };
// Storage location codes for tensor data. onnx.Utility.formatLocation maps each
// value to its lower-cased key (0 -> 'default', 1 -> 'external'); onnx.Tensor only
// decodes data whose location formats as 'default'.
// NOTE(review): values presumably mirror the ONNX TensorProto.DataLocation enum — confirm.
onnx.DataLocation = {
    DEFAULT: 0,
    EXTERNAL: 1
};
// Tensor element-type codes. onnx.Tensor compares tensor `data_type` values against
// these when decoding raw buffers, and onnx.Utility.formatElementType maps each code
// to a display name (e.g. FLOAT -> 'float32').
// NOTE(review): values presumably mirror the ONNX TensorProto.DataType enum — the exact
// numbers matter for decoding, so confirm against onnx.proto before changing any.
onnx.DataType = {
    UNDEFINED: 0,
    FLOAT: 1,
    UINT8: 2,
    INT8: 3,
    UINT16: 4,
    INT16: 5,
    INT32: 6,
    INT64: 7,
    STRING: 8,
    BOOL: 9,
    FLOAT16: 10,
    DOUBLE: 11,
    UINT32: 12,
    UINT64: 13,
    COMPLEX64: 14,
    COMPLEX128: 15,
    BFLOAT16: 16
};
// Attribute kind codes. onnx.Utility.attributeType infers one of these when an
// attribute carries no explicit `type`, and createNodes checks for TENSOR /
// SPARSE_TENSOR when folding Constant nodes.
// NOTE(review): values presumably mirror the ONNX AttributeProto.AttributeType enum — confirm.
onnx.AttributeType = {
    UNDEFINED: 0,
    FLOAT: 1,
    INT: 2,
    STRING: 3,
    TENSOR: 4,
    GRAPH: 5,
    FLOATS: 6,
    INTS: 7,
    STRINGS: 8,
    TENSORS: 9,
    GRAPHS: 10,
    SPARSE_TENSOR: 11,
    SPARSE_TENSORS: 12,
    TYPE_PROTO: 13,
    TYPE_PROTOS: 14
};
  1302. onnx.Utility = class {
  1303. static decodeText(value) {
  1304. if (typeof value === 'string') {
  1305. return value;
  1306. }
  1307. onnx.Utility._utf8Decoder = onnx.Utility._utf8Decoder || new TextDecoder('utf-8');
  1308. return onnx.Utility._utf8Decoder.decode(value);
  1309. }
  1310. static formatElementType(elementType) {
  1311. if (!onnx.Utility._elementTypeMap) {
  1312. const map = {};
  1313. map[onnx.DataType.UNDEFINED] = 'UNDEFINED';
  1314. map[onnx.DataType.FLOAT] = 'float32';
  1315. map[onnx.DataType.UINT8] = 'uint8';
  1316. map[onnx.DataType.INT8] = 'int8';
  1317. map[onnx.DataType.UINT16] = 'uint16';
  1318. map[onnx.DataType.INT16] = 'int16';
  1319. map[onnx.DataType.INT32] = 'int32';
  1320. map[onnx.DataType.INT64] = 'int64';
  1321. map[onnx.DataType.STRING] = 'string';
  1322. map[onnx.DataType.BOOL] = 'boolean';
  1323. map[onnx.DataType.FLOAT16] = 'float16';
  1324. map[onnx.DataType.DOUBLE] = 'float64';
  1325. map[onnx.DataType.UINT32] = 'uint32';
  1326. map[onnx.DataType.UINT64] = 'uint64';
  1327. map[onnx.DataType.COMPLEX64] = 'complex64';
  1328. map[onnx.DataType.COMPLEX128] = 'complex128';
  1329. map[onnx.DataType.BFLOAT16] = 'bfloat16';
  1330. onnx.Utility._elementTypeMap = map;
  1331. }
  1332. const name = onnx.Utility._elementTypeMap[elementType];
  1333. if (name) {
  1334. return name;
  1335. }
  1336. return onnx.Utility._elementTypeMap[onnx.DataType.UNDEFINED];
  1337. }
  1338. static formatType(type, imageFormat) {
  1339. if (!type) {
  1340. return null;
  1341. }
  1342. let denotation = '';
  1343. switch (type.denotation) {
  1344. case 'TENSOR':
  1345. denotation = 'Tensor';
  1346. break;
  1347. case 'IMAGE':
  1348. denotation = 'Image' + (imageFormat ? '(' + imageFormat.join(',') + ')' : '');
  1349. break;
  1350. case 'AUDIO':
  1351. denotation = 'Audio';
  1352. break;
  1353. case 'TEXT':
  1354. denotation = 'Text';
  1355. break;
  1356. }
  1357. switch (type.value) {
  1358. case 'tensor_type': {
  1359. {
  1360. const tensor_type = type.tensor_type;
  1361. let shape = [];
  1362. if (tensor_type.shape && tensor_type.shape.dim) {
  1363. shape = tensor_type.shape.dim.map((dim) => dim.dim_param ? dim.dim_param : dim.dim_value);
  1364. }
  1365. return new onnx.TensorType(tensor_type.elem_type, new onnx.TensorShape(shape), denotation);
  1366. }
  1367. }
  1368. case 'sparse_tensor_type': {
  1369. const tensor_type = type.sparse_tensor_type;
  1370. let shape = [];
  1371. if (tensor_type.shape && tensor_type.shape.dim) {
  1372. shape = tensor_type.shape.dim.map((dim) => dim.dim_param ? dim.dim_param : dim.dim_value);
  1373. }
  1374. return new onnx.TensorType(tensor_type.elem_type, new onnx.TensorShape(shape), denotation);
  1375. }
  1376. case 'map_type': {
  1377. return new onnx.MapType(type.map_type.key_type, onnx.Utility.formatType(type.map_type.value_type, imageFormat), denotation);
  1378. }
  1379. case 'sequence_type': {
  1380. return new onnx.SequenceType(onnx.Utility.formatType(type.sequence_type.elem_type, imageFormat), denotation);
  1381. }
  1382. case 'opaque_type': {
  1383. return new onnx.OpaqueType(type.opaque_type.domain, type.opaque_type.name);
  1384. }
  1385. }
  1386. return null;
  1387. }
  1388. static formatLocation(location) {
  1389. if (!onnx.Utility._dataLocations) {
  1390. onnx.Utility._dataLocations = new Map(Object.keys(onnx.DataLocation).map((key) => [ onnx.DataLocation[key], key.toLowerCase() ]));
  1391. }
  1392. return onnx.Utility._dataLocations.get(location);
  1393. }
  1394. static attributeType(attribute) {
  1395. if (attribute.type) {
  1396. return attribute.type;
  1397. }
  1398. if (attribute.ints && attribute.ints.length > 0) {
  1399. return onnx.AttributeType.INTS;
  1400. }
  1401. else if (attribute.floats && attribute.floats.length > 0) {
  1402. return onnx.AttributeType.FLOATS;
  1403. }
  1404. else if (attribute.strings && attribute.strings.length > 0) {
  1405. return onnx.AttributeType.STRINGS;
  1406. }
  1407. else if (attribute.graphs && attribute.graphs.length > 0) {
  1408. return onnx.AttributeType.GRAPHS;
  1409. }
  1410. else if (attribute.s && attribute.s.length > 0) {
  1411. return onnx.AttributeType.STRING;
  1412. }
  1413. else if (Object.prototype.hasOwnProperty.call(attribute, 'f')) {
  1414. return onnx.AttributeType.FLOAT;
  1415. }
  1416. else if (Object.prototype.hasOwnProperty.call(attribute, 'i')) {
  1417. return onnx.AttributeType.INT;
  1418. }
  1419. else if (Object.prototype.hasOwnProperty.call(attribute, 't')) {
  1420. return onnx.AttributeType.TENSOR;
  1421. }
  1422. else if (Object.prototype.hasOwnProperty.call(attribute, 'g')) {
  1423. return onnx.AttributeType.GRAPH;
  1424. }
  1425. else if (Object.prototype.hasOwnProperty.call(attribute, 'sparse_tensor')) {
  1426. return onnx.AttributeType.SPARSE_TENSOR;
  1427. }
  1428. return onnx.AttributeType.UNDEFINED;
  1429. }
  1430. static createTensors(nodes) {
  1431. const tensors = new Map();
  1432. tensors.map = function(name) {
  1433. if (!this.has(name)) {
  1434. this.set(name, { name: name });
  1435. }
  1436. return this.get(name);
  1437. };
  1438. for (const node of nodes) {
  1439. node.input = node.input.map((name) => tensors.map(name));
  1440. node.output = node.output.map((name) => tensors.map(name));
  1441. node.param = {};
  1442. for (const attribute of node.attribute) {
  1443. attribute.type = onnx.Utility.attributeType(attribute);
  1444. }
  1445. }
  1446. return tensors;
  1447. }
  1448. static createNodes(context, nodes, inputs, outputs, tensors, args) {
  1449. const inputMap = new Map();
  1450. const outputMap = new Map();
  1451. for (const node of nodes) {
  1452. node.input.every((input) => inputMap.set(input.name, (inputMap.get(input) || 0) + 1));
  1453. node.output.every((output) => outputMap.set(output.name, (outputMap.get(output) || 0) + 1));
  1454. }
  1455. inputs.every((input) => inputMap.delete(input.name));
  1456. outputs.every((output) => outputMap.delete(output.name));
  1457. nodes = nodes.filter((node) => {
  1458. const constant = node &&
  1459. node.op_type === 'Constant' &&
  1460. node.attribute.length === 1 && node.attribute[0] &&
  1461. node.input.length === 0 &&
  1462. node.output.length === 1 && node.output[0] && inputMap.get(node.output[0].name) === 1 && outputMap.get(node.output[0].name) === 1;
  1463. const attribute = constant ? node.attribute[0] : null;
  1464. if (attribute && attribute.name === 'value' && attribute.type === onnx.AttributeType.TENSOR && attribute.t) {
  1465. const tensor = tensors.map(node.output[0].name);
  1466. tensor.initializer = new onnx.Tensor(attribute.t, 'Constant');
  1467. return false;
  1468. }
  1469. else if (attribute && attribute.name === 'sparse_value' && attribute.type === onnx.AttributeType.SPARSE_TENSOR && attribute.sparse_tensor) {
  1470. const tensor = tensors.map(node.output[0].name);
  1471. tensor.initializer = new onnx.Tensor(attribute.sparse_tensor, 'Sparse Constant');
  1472. return false;
  1473. }
  1474. return true;
  1475. });
  1476. return nodes.map((node) => {
  1477. const schema = context.metadata.type(node.op_type, node.domain);
  1478. const inputs = [];
  1479. node.input = node.input || [];
  1480. for (let i = 0; i < node.input.length; ) {
  1481. const input = schema && schema.inputs && i < schema.inputs.length ? schema.inputs[i] : { name: i.toString() };
  1482. const count = input.list ? node.input.length - i : 1;
  1483. const list = node.input.slice(i, i + count).map((input) => args.map(input.name));
  1484. inputs.push(new onnx.Parameter(input.name, list));
  1485. i += count;
  1486. }
  1487. const outputs = [];
  1488. node.output = node.output || [];
  1489. for (let i = 0; i < node.output.length; ) {
  1490. const output = schema && schema.outputs && i < schema.outputs.length ? schema.outputs[i] : { name: i.toString() };
  1491. const count = output.list ? node.output.length - i : 1;
  1492. const list = node.output.slice(i, i + count).map((output) => args.map(output.name));
  1493. outputs.push(new onnx.Parameter(output.name, list));
  1494. i += count;
  1495. }
  1496. return new onnx.Node(context, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs);
  1497. });
  1498. }
  1499. };
  1500. onnx.Error = class extends Error {
  1501. constructor(message) {
  1502. super(message);
  1503. this.name = 'Error loading ONNX model.';
  1504. }
  1505. };
  1506. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  1507. module.exports.ModelFactory = onnx.ModelFactory;
  1508. }