coreml.js 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570
  1. import * as base from './base.js';
  const coreml = {};

  coreml.ModelFactory = class {

      /**
       * Identifies which Core ML container format `context` holds.
       *
       * Recognized inputs: binary protobuf models ('coreml.pb'), text protobuf models
       * ('coreml.pbtxt'), .mlpackage manifests ('coreml.manifest'), MIL program text
       * ('coreml.mil'), feature descriptions / metadata JSON ('coreml.featuredescriptions',
       * 'coreml.metadata', 'coreml.metadata.mlmodelc') and MIL weight blobs ('coreml.weights').
       *
       * @param {object} context - host model context (identifier, stream, tags/peek/read/set).
       * @returns {Promise<*>} the result of `context.set(type)` when recognized, otherwise null.
       */
      async match(context) {
          const stream = context.stream;
          const identifier = context.identifier.toLowerCase();
          // `identifier` is already lower-case, so the extension is too.
          const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop() : '';
          const tags = await context.tags('pb');
          // NOTE(review): presumably `tags` maps protobuf field numbers to wire types, so this looks
          // for a varint field 1 and a length-delimited field 2 — confirm against context.tags().
          if (tags.get(1) === 0 && tags.get(2) === 2) {
              // Field-number ranges accepted as Core ML payloads (presumably the model-type oneof).
              const match = (key) =>
                  (key >= 200 && key < 220) || (key >= 300 && key < 320) || (key >= 400 && key < 420) ||
                  (key >= 500 && key < 520) || (key >= 550 && key < 560) || (key >= 600 && key < 620) ||
                  (key === 900) ||
                  (key >= 2000 && key < 2010) || (key === 3000);
              // A generic `.pb` file is only claimed if it carries at least one of those fields.
              if (extension === 'pb' && Array.from(tags.keys()).every((key) => !match(key))) {
                  return null;
              }
              return context.set('coreml.pb');
          }
          if (extension === 'pbtxt') {
              const tags = await context.tags('pbtxt');
              if (tags.has('specificationVersion') && tags.has('description')) {
                  return context.set('coreml.pbtxt');
              }
          }
          if (identifier === 'manifest.json') {
              const obj = await context.peek('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  // A package manifest must reference exactly one `.mlmodel` entry
                  // (the same requirement open() enforces).
                  const entries = Object.values(obj.itemInfoEntries).filter((entry) =>
                      entry && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel'));
                  if (entries.length === 1) {
                      return context.set('coreml.manifest');
                  }
              }
          }
          if (identifier === 'model.mil') {
              try {
                  const reader = await context.read('text', 2048);
                  const signature = reader.read('\n');
                  if (signature && signature.trim().startsWith('program')) {
                      return context.set('coreml.mil');
                  }
              } catch {
                  // continue regardless of error
              }
          }
          if (identifier === 'featuredescriptions.json') {
              const obj = await context.peek('json');
              if (obj && (obj.Inputs || obj.Outputs)) {
                  return context.set('coreml.featuredescriptions');
              }
          }
          if (identifier === 'metadata.json') {
              const obj = await context.peek('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  return context.set('coreml.metadata');
              }
              if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) {
                  return context.set('coreml.metadata.mlmodelc');
              }
          }
          if (extension === 'bin' && stream && stream.length > 16) {
              const buffer = stream.peek(Math.min(256, stream.length));
              // Scan every 4-byte window — including the last one — for the little-endian
              // 0xDEADBEEF sentinel.
              for (let i = 0; i + 4 <= buffer.length; i++) {
                  const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer[i + 3] << 24) >>> 0;
                  if (signature === 0xdeadbeef) {
                      return context.set('coreml.weights');
                  }
              }
          }
          return null;
      }

      /**
       * Rejects the MIL reader when the context was identified as a compiled-model
       * (`.mlmodelc`) metadata file.
       *
       * @returns {boolean} false to suppress `type` for this context, otherwise true.
       */
      filter(context, type) {
          if (context.type === 'coreml.metadata.mlmodelc' && (type === 'coreml.mil')) {
              return false;
          }
          return true;
      }

      /**
       * Opens a context previously identified by match().
       *
       * @param {object} context - host model context whose `type` was set by match().
       * @returns {Promise<coreml.Model>} the decoded model.
       * @throws {coreml.Error} for undecodable or unsupported inputs.
       */
      async open(context) {
          const proto = await context.require('./coreml-proto');
          coreml.proto = proto.CoreML.Specification;
          const metadata = await context.metadata('coreml-metadata.json');
          // Render any thrown value as text without assuming it is an Error instance.
          const describe = (error) => error && error.message ? error.message : String(error);
          // Collects every `blobFileValue.fileName` referenced by MIL programs in `model`,
          // including programs nested inside pipeline / pipelineClassifier / pipelineRegressor.
          const collectWeightPaths = (model) => {
              const paths = new Set();
              const walkBlock = (block) => {
                  for (const operation of block.operations || []) {
                      for (const value of Object.values(operation.attributes || {})) {
                          if (value && value.blobFileValue && value.blobFileValue.fileName) {
                              paths.add(value.blobFileValue.fileName);
                          }
                      }
                      // Operations may carry nested blocks (presumably control flow) — walk them too.
                      if (Array.isArray(operation.blocks)) {
                          for (const nested of operation.blocks) {
                              walkBlock(nested);
                          }
                      }
                  }
              };
              const walkProgram = (program) => {
                  for (const func of Object.values(program.functions || {})) {
                      for (const block of Object.values(func.block_specializations || {})) {
                          walkBlock(block);
                      }
                  }
              };
              const walkModel = (model) => {
                  if (model.mlProgram) {
                      walkProgram(model.mlProgram);
                  }
                  const pipelines = [
                      model.pipeline,
                      model.pipelineClassifier && model.pipelineClassifier.pipeline,
                      model.pipelineRegressor && model.pipelineRegressor.pipeline
                  ];
                  for (const pipeline of pipelines) {
                      if (pipeline && pipeline.models) {
                          for (const node of pipeline.models) {
                              walkModel(node);
                          }
                      }
                  }
              };
              walkModel(model);
              return paths;
          };
          // Best-effort fetch of external weight files keyed by their original reference;
          // files that cannot be fetched are skipped instead of discarding every weight.
          const loadWeights = async (context, path, weightPaths) => {
              const weights = new Map();
              if (weightPaths.size === 0) {
                  return weights;
              }
              // `@model_path/` is resolved against the folder containing the model file
              // ('' when `path` has no folder component).
              const index = path.lastIndexOf('/');
              const folder = index >= 0 ? path.substring(0, index + 1) : '';
              const keys = Array.from(weightPaths);
              // The async wrapper turns synchronous fetch failures into rejections, and the
              // function replacer keeps `$` sequences in `folder` literal.
              const results = await Promise.allSettled(keys.map(async (key) =>
                  context.fetch(key.replace(/^@model_path\//, () => folder))));
              results.forEach((result, i) => {
                  if (result.status === 'fulfilled' && result.value) {
                      weights.set(keys[i], result.value.stream);
                  }
              });
              return weights;
          };
          const openBinary = async (content, context, path, format) => {
              let model = null;
              try {
                  const reader = await content.read('protobuf.binary');
                  model = coreml.proto.Model.decode(reader);
              } catch (error) {
                  const message = describe(error);
                  throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`);
              }
              const weights = await loadWeights(context, path, collectWeightPaths(model));
              const name = `${format || 'Core ML'} v${model.specificationVersion}`;
              return new coreml.Model(new coreml.Context(metadata, name, model, weights));
          };
          const openText = async (context) => {
              let model = null;
              try {
                  const reader = await context.read('protobuf.text');
                  model = coreml.proto.Model.decodeText(reader);
              } catch (error) {
                  const message = describe(error);
                  throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`);
              }
              const format = `Core ML v${model.specificationVersion}`;
              return new coreml.Model(new coreml.Context(metadata, format, model));
          };
          const openManifest = async (obj, context, path) => {
              const items = obj && obj.itemInfoEntries ? Object.values(obj.itemInfoEntries) : [];
              const entries = items.filter((entry) =>
                  entry && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel'));
              if (entries.length !== 1) {
                  throw new coreml.Error('Manifest does not contain Core ML model.');
              }
              const name = `${path}Data/${entries[0].path}`;
              const content = await context.fetch(name);
              return openBinary(content, context, name, 'Core ML Package');
          };
          const openManifestStream = async (context, path) => {
              const name = `${path}Manifest.json`;
              const content = await context.fetch(name);
              const obj = await content.read('json');
              return openManifest(obj, context, path);
          };
          switch (context.type) {
              case 'coreml.pb':
                  return openBinary(context, context, '');
              case 'coreml.pbtxt':
                  return openText(context);
              case 'coreml.manifest': {
                  const obj = await context.peek('json');
                  return openManifest(obj, context, '');
              }
              case 'coreml.featuredescriptions':
              case 'coreml.metadata':
                  return openManifestStream(context, '../../');
              case 'coreml.metadata.mlmodelc':
                  throw new coreml.Error('Core ML Model Archive format is not supported.');
              case 'coreml.mil':
                  throw new coreml.Error('Core ML MIL format is not supported.');
              case 'coreml.weights':
                  return openManifestStream(context, '../../../');
              default:
                  throw new coreml.Error(`Unsupported Core ML format '${context.type}'.`);
          }
      }
  };
  coreml.Model = class {

      /**
       * Top-level model view built from a coreml.Context.
       *
       * @param {object} context - provides `format`, `metadata`, `graphs`, `functions`
       *     and optional `version` / `description`.
       */
      constructor(context) {
          this.format = context.format;
          this.metadata = Array.from(context.metadata);
          const wrap = (graphContext) => new coreml.Graph(graphContext);
          this.graphs = context.graphs.map(wrap);
          this.functions = context.functions.map(wrap);
          // Optional fields are copied only when present (truthy).
          for (const key of ['version', 'description']) {
              if (context[key]) {
                  this[key] = context[key];
              }
          }
      }
  };
  coreml.Graph = class {

      /**
       * Converts a graph context into display objects: each value-table record gets a
       * coreml.Value, inputs/outputs become coreml.Argument lists and nodes become
       * coreml.Node instances. `loop` and `branch` sub-networks are converted recursively.
       *
       * @param {object} context - graph context with `values` (Map), `inputs`, `outputs`,
       *     `nodes` and optional `name`, `type`, `description`, `groups`.
       */
      constructor(context) {
          this.name = context.name || '';
          this.type = context.type || '';
          this.description = context.description;
          this.groups = context.groups;
          // The same record can sit under several keys; materialize its Value only once.
          context.values.forEach((record) => {
              if (!record.value) {
                  record.value = new coreml.Value(record.name, record.type, record.description, record.initializer);
              }
          });
          const toArgument = (argument) => {
              const values = argument.value.map((entry) => entry.value);
              return new coreml.Argument(argument.name, values, null, argument.visible);
          };
          this.inputs = context.inputs.map(toArgument);
          this.outputs = context.outputs.map(toArgument);
          // Control-flow nodes carry sub-network contexts as attributes; wrap them as graphs.
          const subgraphs = new Map([
              ['loop', ['conditionNetwork', 'bodyNetwork']],
              ['branch', ['ifBranch', 'elseBranch']]
          ]);
          for (const obj of context.nodes) {
              const keys = subgraphs.get(obj.type) || [];
              for (const key of keys) {
                  obj.attributes[key] = new coreml.Graph(obj.attributes[key]);
              }
          }
          this.nodes = context.nodes.map((obj) => new coreml.Node(context, obj));
      }
  };
  coreml.Argument = class {

      /**
       * Named argument (input, output, attribute or metadata entry).
       *
       * @param {string} name - argument name.
       * @param {*} value - value(s) carried by the argument.
       * @param {*} [type] - optional type; any falsy value is stored as null.
       * @param {boolean} [visible] - only an explicit `false` hides the argument.
       */
      constructor(name, value, type, visible) {
          Object.assign(this, {
              name,
              value,
              type: type || null,
              visible: visible !== false
          });
      }
  };
  coreml.Value = class {

      /**
       * Graph value (edge). When no explicit type is given, the initializer's type is used.
       *
       * @param {string} name - value identifier; must be a string.
       * @param {*} type - value type, may be falsy.
       * @param {string} [description] - optional description (falsy becomes null).
       * @param {coreml.Tensor} [initializer] - optional constant tensor backing the value.
       * @throws {coreml.Error} when `name` is not a string.
       */
      constructor(name, type, description, initializer) {
          if (typeof name === 'string') {
              const hasInitializer = Boolean(initializer);
              this.name = name;
              this.type = type || !hasInitializer ? type : initializer.type;
              this.description = description || null;
              this.initializer = hasInitializer ? initializer : null;
              this.quantization = hasInitializer ? initializer.quantization : null;
              return;
          }
          throw new coreml.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
  };
  coreml.Node = class {

      /**
       * Builds a display node from an intermediate node record.
       *
       * @param {object} context - graph context; only `context.metadata` (its `type()` and
       *     `attribute()` lookups) is used here.
       * @param {object} obj - node record with `type` and optional `name`, `group`,
       *     `description`, `inputs`, `outputs`, `attributes` and `chain`.
       * @throws {coreml.Error} when `obj.type` is missing.
       */
      constructor(context, obj) {
          if (!obj.type) {
              throw new coreml.Error('Undefined node type.');
          }
          if (obj.group) {
              this.group = obj.group;
          }
          const type = context.metadata.type(obj.type);
          // Copy the metadata entry so renaming below does not mutate shared metadata.
          this.type = type ? { ...type } : { name: obj.type };
          // Display only the part after the last ':' (drops any namespace prefix).
          this.type.name = obj.type.split(':').pop();
          this.name = obj.name || '';
          this.description = obj.description || '';
          const toArgument = (argument) => {
              const values = argument.value.map((value) => value.value);
              return new coreml.Argument(argument.name, values, null, argument.visible);
          };
          this.inputs = (obj.inputs || []).map(toArgument);
          this.outputs = (obj.outputs || []).map(toArgument);
          // Compares a value with its metadata default. Values JSON cannot serialize (for
          // example objects holding BigInt fields) count as non-default instead of throwing.
          const isDefault = (expected, actual) => {
              try {
                  return JSON.stringify(expected) === JSON.stringify(actual);
              } catch {
                  return false;
              }
          };
          this.attributes = Object.entries(obj.attributes || {}).map(([name, value]) => {
              const metadata = context.metadata.attribute(obj.type, name);
              let type = null;
              let visible = true;
              if (value instanceof coreml.Tensor) {
                  type = 'tensor';
              }
              if (value instanceof coreml.Graph) {
                  type = 'graph';
              }
              if (metadata) {
                  type = metadata.type ? metadata.type : type;
                  if (type && coreml.proto) {
                      value = coreml.Utility.enum(type, value);
                  }
                  if (metadata.visible === false) {
                      visible = false;
                  } else if (metadata.default !== undefined) {
                      // Arrays and BigInt scalars are converted to Numbers before the comparison
                      // (the converted value is also what gets displayed).
                      if (Array.isArray(value)) {
                          value = value.map((item) => Number(item));
                      }
                      if (typeof value === 'bigint') {
                          value = Number(value);
                      }
                      if (isDefault(metadata.default, value)) {
                          visible = false;
                      }
                  }
              }
              return new coreml.Argument(name, value, type, visible);
          });
          if (Array.isArray(obj.chain)) {
              this.chain = obj.chain.map((item) => new coreml.Node(context, item));
          }
      }
  };
  coreml.Tensor = class {

      /**
       * Constant tensor (layer weights / constants).
       *
       * `encoding` is '|' for float32, '>' for names of the form `uintN` / `intN` with a
       * single-digit N, and '<' otherwise. `quantization` is set from linear quantization
       * (scale and bias arrays) or, taking precedence, from a non-empty lookup table.
       *
       * @param {coreml.TensorType} type - tensor type; `type.dataType` must be a string.
       * @param {*} values - raw values.
       * @param {object} [quantization] - optional quantization parameters.
       * @param {string} [category] - display category.
       */
      constructor(type, values, quantization, category) {
          this.type = type;
          this.values = values;
          this.category = category;
          const dataType = type.dataType;
          if (dataType === 'float32') {
              this.encoding = '|';
          } else {
              const unsignedShort = dataType.startsWith('uint') && dataType.length === 5;
              const signedShort = dataType.startsWith('int') && dataType.length === 4;
              this.encoding = unsignedShort || signedShort ? '>' : '<';
          }
          const linear = quantization && quantization.linearQuantization;
          if (linear && Array.isArray(linear.scale) && Array.isArray(linear.bias)) {
              this.quantization = { type: 'linear', scale: linear.scale, bias: linear.bias };
          }
          const lookup = quantization && quantization.lookupTableQuantization;
          if (lookup && lookup.floatValue && lookup.floatValue.length > 0) {
              this.quantization = { type: 'lookup', value: lookup.floatValue };
          }
      }
  };
  coreml.TensorType = class {

      /**
       * @param {string} dataType - element type name (e.g. 'float32', '?').
       * @param {coreml.TensorShape} [shape] - defaults to an empty shape.
       */
      constructor(dataType, shape) {
          this.dataType = dataType;
          this.shape = shape ? shape : new coreml.TensorShape([]);
      }

      /**
       * @returns {*} falsy `obj` is returned as-is; otherwise whether data types and shapes match.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (!this.shape) {
              return this.shape;
          }
          return this.shape.equals(obj.shape);
      }

      /** @returns {string} data type followed by the shape's string form. */
      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  coreml.TensorShape = class {

      /**
       * @param {Iterable<number|bigint>|ArrayLike<number|bigint>} [dimensions] - dimension sizes.
       *     BigInt entries are converted to Numbers; a missing list yields a rank-0 shape.
       */
      constructor(dimensions) {
          // BigInt has no `toNumber()` method — `Number()` is the supported conversion.
          // Array.from also accepts typed arrays (e.g. BigInt64Array) and always yields a plain array.
          this.dimensions = Array.from(dimensions || [], (dim) => typeof dim === 'bigint' ? Number(dim) : dim);
      }

      /**
       * @returns {*} falsy `obj` is returned as-is; otherwise whether both shapes have the
       *     same dimensions in the same order.
       */
      equals(obj) {
          return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) &&
              this.dimensions.length === obj.dimensions.length &&
              obj.dimensions.every((value, index) => this.dimensions[index] === value);
      }

      /** @returns {string} `[d0,d1,...]`, or '' for a rank-0 shape. */
      toString() {
          return Array.isArray(this.dimensions) && this.dimensions.length > 0 ?
              `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]` : '';
      }
  };
  coreml.ListType = class {

      /** @param {*} elementType - type of each element; must provide `equals()`. */
      constructor(elementType) {
          this.elementType = elementType;
      }

      /** @returns {*} false unless `obj` is a ListType; otherwise the element-type comparison. */
      equals(obj) {
          if (!(obj instanceof coreml.ListType)) {
              return false;
          }
          return this.elementType.equals(obj.elementType);
      }

      /** @returns {string} `list<...>` around the element type's string form. */
      toString() {
          const element = this.elementType;
          return `list<${element}>`;
      }
  };
  coreml.MapType = class {

      /**
       * @param {*} keyType - key type; must provide `equals()`.
       * @param {*} valueType - value type; must provide `equals()`.
       */
      constructor(keyType, valueType) {
          this.keyType = keyType;
          this.valueType = valueType;
      }

      /** @returns {*} false unless `obj` is a MapType; otherwise the key and value comparisons. */
      equals(obj) {
          if (!(obj instanceof coreml.MapType)) {
              return false;
          }
          const sameKey = this.keyType.equals(obj.keyType);
          return sameKey && this.valueType.equals(obj.valueType);
      }

      /** @returns {string} `map<key,value>` using the component types' string forms. */
      toString() {
          const { keyType, valueType } = this;
          return `map<${keyType},${valueType}>`;
      }
  };
  coreml.SequenceType = class {

      /** @param {*} type - element type; must provide `equals()`. */
      constructor(type) {
          this.type = type;
      }

      /** @returns {*} false unless `obj` is a SequenceType; otherwise the element comparison. */
      equals(obj) {
          if (!(obj instanceof coreml.SequenceType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      /** @returns {string} `sequence<...>` around the element type's string form. */
      toString() {
          const inner = this.type;
          return `sequence<${inner}>`;
      }
  };
  coreml.ImageType = class {

      /**
       * Image feature type.
       *
       * @param {number} colorSpace - a `coreml.proto.ImageFeatureType.ColorSpace` value.
       * @param {number} width - image width.
       * @param {number} height - image height.
       * @throws {coreml.Error} for a color space not listed below.
       */
      constructor(colorSpace, width, height) {
          this.width = width;
          this.height = height;
          const space = coreml.proto.ImageFeatureType.ColorSpace;
          // Checked in order with strict equality, mirroring a switch statement.
          const names = [
              [space.GRAYSCALE, 'grayscale'],
              [space.RGB, 'RGB'],
              [space.BGR, 'BGR'],
              [space.GRAYSCALE_FLOAT16, 'grayscale:float16']
          ];
          const match = names.find(([value]) => value === colorSpace);
          if (!match) {
              throw new coreml.Error(`Unsupported image color space '${colorSpace}'.`);
          }
          this.colorSpace = match[1];
      }

      /** @returns {boolean} whether `obj` is an ImageType with equal size and color space. */
      equals(obj) {
          if (!(obj instanceof coreml.ImageType)) {
              return false;
          }
          return this.width === obj.width && this.height === obj.height && this.colorSpace === obj.colorSpace;
      }

      /** @returns {string} `image<colorSpace,WxH>`. */
      toString() {
          const size = `${this.width.toString()}x${this.height}`;
          return `image<${this.colorSpace},${size}>`;
      }
  };
  coreml.OptionalType = class {

      /** @param {*} type - wrapped type; must provide `equals()`. */
      constructor(type) {
          this.type = type;
      }

      /** @returns {*} false unless `obj` is an OptionalType; otherwise the wrapped comparison. */
      equals(obj) {
          if (!(obj instanceof coreml.OptionalType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      /** @returns {string} `optional<...>` around the wrapped type's string form. */
      toString() {
          const inner = this.type;
          return `optional<${inner}>`;
      }
  };
  coreml.StateType = class {

      /** @param {*} type - wrapped type; must provide `equals()`. */
      constructor(type) {
          this.type = type;
      }

      /** @returns {*} false unless `obj` is a StateType; otherwise the wrapped comparison. */
      equals(obj) {
          if (!(obj instanceof coreml.StateType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      /** @returns {string} `state<...>` around the wrapped type's string form. */
      toString() {
          const inner = this.type;
          return `state<${inner}>`;
      }
  };
  coreml.Context = class {

      /**
       * Collects graph/function contexts and model-level metadata for a decoded model.
       *
       * One 'function' context is created per `description.functions` entry, with the
       * `description.defaultFunctionName` function moved to the front. A 'graph' context is
       * created for models without an `mlProgram`, or whose program defines `main`.
       *
       * @param {object} metadata - operator metadata passed through to graph contexts.
       * @param {string} format - display format string.
       * @param {object} model - decoded CoreML.Specification.Model.
       * @param {Map} [weights] - external weight streams keyed by blob file name.
       * @param {Map} [values] - shared value table.
       */
      constructor(metadata, format, model, weights, values) {
          this.format = format;
          this.metadata = [];
          this.graphs = [];
          this.functions = [];
          const description = model ? model.description : null;
          const functions = description && description.functions ? description.functions : [];
          for (const func of functions) {
              const graph = new coreml.Context.Graph(metadata, func.name, 'function', model, func, weights, values);
              this.functions.push(graph);
          }
          if (description && description.defaultFunctionName) {
              // Search and reorder `this.functions` itself (the only list populated so far).
              const index = this.functions.findIndex((graph) => graph.name === description.defaultFunctionName);
              if (index > 0) {
                  const [graph] = this.functions.splice(index, 1);
                  this.functions.unshift(graph);
              }
          }
          const program = model ? model.mlProgram : null;
          if (model && (!program || (program.functions && program.functions.main))) {
              const graph = new coreml.Context.Graph(metadata, '', 'graph', model, description, weights, values);
              this.graphs.push(graph);
          }
          const info = description ? description.metadata : null;
          if (info) {
              if (info.versionString) {
                  this.version = info.versionString;
              }
              if (info.shortDescription) {
                  this.description = info.shortDescription;
              }
              if (info.author) {
                  this.metadata.push(new coreml.Argument('author', info.author));
              }
              if (info.license) {
                  this.metadata.push(new coreml.Argument('license', info.license));
              }
              // `info.userDefined` entries are not surfaced as metadata.
          }
      }
  };
  521. coreml.Context.Graph = class {
  522. constructor(metadata, name, type, model, description, weights, values) {
  523. this.metadata = metadata;
  524. this.name = name;
  525. this.type = type;
  526. this.weights = weights || new Map();
  527. this.values = values || new Map();
  528. this.nodes = [];
  529. this.inputs = [];
  530. this.outputs = [];
  531. if (description) {
  532. const inputs = description && Array.isArray(description.input) ? description.input : [];
  533. for (const description of inputs) {
  534. const value = this.output(description.name);
  535. this.update(value, description);
  536. this.inputs.push({ name: description.name, visible: true, value: [value] });
  537. }
  538. const state = description && Array.isArray(description.state) ? description.state : [];
  539. for (const description of state) {
  540. const value = this.output(description.name);
  541. this.update(value, description);
  542. this.inputs.push({ name: description.name, visible: true, value: [value] });
  543. }
  544. this.description = this.model(model, '', description);
  545. const outputs = description && Array.isArray(description.output) ? description.output : [];
  546. for (const description of outputs) {
  547. const value = this.input(description.name);
  548. this.update(value, description);
  549. this.outputs.push({ name: description.name, visible: true, value: [value] });
  550. }
  551. }
  552. }
  553. context() {
  554. return new coreml.Context.Graph(this.metadata, '', 'graph', null, null, this.weights, this.values);
  555. }
  556. network(obj) {
  557. const context = this.context();
  558. for (const layer of obj.layers) {
  559. const type = layer.layer;
  560. context.node(context.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  561. }
  562. context.updatePreprocessing('', obj.preprocessing, null);
  563. context.description = 'Neural Network';
  564. return context;
  565. }
  566. input(name) {
  567. if (!this.values.has(name)) {
  568. this.values.set(name, { counter: 0, name, to: [], from: [] });
  569. }
  570. return this.values.get(name);
  571. }
  572. output(name) {
  573. if (this.values.has(name)) {
  574. const value = { ...this.values.get(name) };
  575. value.counter++;
  576. value.name = `${name}|${value.counter}`; // custom argument id
  577. this.values.set(name, value);
  578. this.values.set(value.name, value);
  579. } else {
  580. const value = { counter: 0, name, to: [], from: [] };
  581. this.values.set(name, value);
  582. const key = `${name}|${value.counter}`;
  583. this.values.set(key, value);
  584. }
  585. return this.values.get(name);
  586. }
  587. update(value, description) {
  588. if (!value.type) {
  589. value.type = coreml.Utility.featureType(description.type);
  590. }
  591. if (!value.description && description.shortDescription) {
  592. value.description = description.shortDescription;
  593. }
  594. }
  595. node(group, type, name, description, data, inputs, outputs, inputTensors, outputTensors) {
  596. const obj = {
  597. group,
  598. type,
  599. name,
  600. description,
  601. attributes: {},
  602. inputs: [],
  603. outputs: []
  604. };
  605. inputs = inputs.map((input, index) => {
  606. const value = this.input(input);
  607. if (!value.type && inputTensors && index < inputTensors.length) {
  608. const tensor = inputTensors[index];
  609. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  610. value.type = new coreml.TensorType('?', shape);
  611. }
  612. return value;
  613. });
  614. outputs = outputs.map((output, index) => {
  615. const value = this.output(output);
  616. if (!value.type && outputTensors && index < outputTensors.length) {
  617. const tensor = outputTensors[index];
  618. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  619. value.type = new coreml.TensorType('?', shape);
  620. }
  621. return value;
  622. });
  623. const initializers = [];
  624. const initializer = (type, name, shape, data) => {
  625. let dataType = '?';
  626. let quantization = null;
  627. let values = null;
  628. if (data) {
  629. if (data.floatValue && data.floatValue.length > 0) {
  630. values = data.floatValue;
  631. dataType = 'float32';
  632. } else if (data.float16Value && data.float16Value.length > 0) {
  633. values = data.float16Value; // byte[]
  634. dataType = 'float16';
  635. } else if (data.rawValue && data.rawValue.length > 0) {
  636. if (data.quantization) {
  637. values = data.rawValue;
  638. dataType = `uint${data.quantization.numberOfBits}`;
  639. } else {
  640. shape = [];
  641. }
  642. }
  643. quantization = data.quantization || null;
  644. }
  645. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  646. const tensor = new coreml.Tensor(tensorType, values, quantization, 'Weights');
  647. const input = this.metadata.input(type, name);
  648. const visible = input && input.visible === false ? false : true;
  649. const value = { value: new coreml.Value('', null, null, tensor) };
  650. initializers.push({ name, visible, value: [value] });
  651. };
  652. const vector = (value) => {
  653. return (value && Object.keys(value).length === 1 && value.vector) ? value.vector : value;
  654. };
  655. const weights = (type, data) => {
  656. switch (type) {
  657. case 'convolution': {
  658. const weightsShape = [data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1]];
  659. if (data.isDeconvolution) {
  660. weightsShape[0] = data.kernelChannels;
  661. weightsShape[1] = Math.floor(Number(data.outputChannels / (data.nGroups === 0 ? 1 : data.nGroups)));
  662. }
  663. initializer(type, 'weights', weightsShape, data.weights);
  664. if (data.hasBias) {
  665. initializer(type, 'bias', [data.outputChannels], data.bias);
  666. }
  667. return { 'weights': true, 'bias': data.hasBias };
  668. }
  669. case 'innerProduct':
  670. initializer(type, 'weights', [data.outputChannels, data.inputChannels], data.weights);
  671. if (data.hasBias) {
  672. initializer(type, 'bias', [data.outputChannels], data.bias);
  673. }
  674. return { 'weights': true, 'bias': data.hasBias };
  675. case 'batchnorm':
  676. initializer(type, 'gamma', [data.channels], data.gamma);
  677. initializer(type, 'beta', [data.channels], data.beta);
  678. if (data.mean) {
  679. initializer(type, 'mean', [data.channels], data.mean);
  680. }
  681. if (data.variance) {
  682. initializer(type, 'variance', [data.channels], data.variance);
  683. }
  684. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  685. case 'embedding':
  686. initializer(type, 'weights', [data.inputDim, data.outputChannels], data.weights);
  687. return { 'weights': true };
  688. case 'loadConstant':
  689. case 'loadConstantND':
  690. initializer(type, 'data', data.shape, data.data);
  691. return { 'data': true };
  692. case 'scale':
  693. initializer(type, 'scale', data.shapeScale, data.scale);
  694. if (data.hasBias) {
  695. initializer(type, 'bias', data.shapeBias, data.bias);
  696. }
  697. return { 'scale': true, 'bias': data.hasBias };
  698. case 'bias':
  699. initializer(type, 'bias', data.shape, data.bias);
  700. return { 'bias': true };
  701. case 'simpleRecurrent':
  702. initializer(type, 'weights', [data.outputVectorSize, data.inputVectorSize], data.weightMatrix);
  703. initializer(type, 'recurrent', [data.outputVectorSize, data.inputVectorSize], data.recursionMatrix);
  704. if (data.hasBiasVectors) {
  705. initializer(type, 'bias', [data.outputVectorSize], data.biasVector);
  706. }
  707. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  708. case 'gru': {
  709. const recursionMatrixShape = [data.outputVectorSize, data.outputVectorSize];
  710. const weightMatrixShape = [data.outputVectorSize, data.inputVectorSize];
  711. const biasVectorShape = [data.outputVectorSize];
  712. initializer(type, 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  713. initializer(type, 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  714. initializer(type, 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  715. initializer(type, 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  716. initializer(type, 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  717. initializer(type, 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  718. if (data.hasBiasVectors) {
  719. initializer(type, 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  720. initializer(type, 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  721. initializer(type, 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  722. }
  723. return {
  724. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  725. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  726. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  727. };
  728. }
  729. case 'uniDirectionalLSTM':
  730. case 'biDirectionalLSTM': {
  731. const count = (type === 'uniDirectionalLSTM') ? 1 : 2;
  732. const h = data.outputVectorSize;
  733. const x = data.inputVectorSize;
  734. for (let i = 0; i < count; i++) {
  735. const weights = count === 1 ? data.weightParams : data.weightParams[i];
  736. const suffix = (i === 0) ? '' : '_rev';
  737. initializer(type, `inputGateWeightMatrix${suffix}`, [h,x], weights.inputGateWeightMatrix);
  738. initializer(type, `forgetGateWeightMatrix${suffix}`, [h,x], weights.forgetGateWeightMatrix);
  739. initializer(type, `blockInputWeightMatrix${suffix}`, [h,x], weights.blockInputWeightMatrix);
  740. initializer(type, `outputGateWeightMatrix${suffix}`, [h,x], weights.outputGateWeightMatrix);
  741. initializer(type, `inputGateRecursionMatrix${suffix}`, [h,h], weights.inputGateRecursionMatrix);
  742. initializer(type, `forgetGateRecursionMatrix${suffix}`, [h,h],weights.forgetGateRecursionMatrix);
  743. initializer(type, `blockInputRecursionMatrix${suffix}`, [h,h], weights.blockInputRecursionMatrix);
  744. initializer(type, `outputGateRecursionMatrix${suffix}`, [h,h], weights.outputGateRecursionMatrix);
  745. if (data.params.hasBiasVectors) {
  746. initializer(type, `inputGateBiasVector${suffix}`, [h], weights.inputGateBiasVector);
  747. initializer(type, `forgetGateBiasVector${suffix}`, [h], weights.forgetGateBiasVector);
  748. initializer(type, `blockInputBiasVector${suffix}`, [h], weights.blockInputBiasVector);
  749. initializer(type, `outputGateBiasVector${suffix}`, [h], weights.outputGateBiasVector);
  750. }
  751. if (data.params.hasPeepholeVectors) {
  752. initializer(type, `inputGatePeepholeVector${suffix}`, [h], weights.inputGatePeepholeVector);
  753. initializer(type, `forgetGatePeepholeVector${suffix}`, [h], weights.forgetGatePeepholeVector);
  754. initializer(type, `outputGatePeepholeVector${suffix}`, [h], weights.outputGatePeepholeVector);
  755. }
  756. }
  757. return { 'weightParams': true };
  758. }
  759. case 'dictVectorizer':
  760. data.stringToIndex = vector(data.stringToIndex);
  761. return {};
  762. case 'wordTagger':
  763. data.modelParameterData = Array.from(data.modelParameterData);
  764. data.stringTags = vector(data.stringTags);
  765. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  766. case 'textClassifier':
  767. data.modelParameterData = Array.from(data.modelParameterData);
  768. data.stringClassLabels = vector(data.stringClassLabels);
  769. return {};
  770. case 'nonMaximumSuppression':
  771. data.stringClassLabels = vector(data.stringClassLabels);
  772. return {};
  773. default:
  774. return {};
  775. }
  776. };
  777. if (data) {
  778. const attributes = obj.attributes;
  779. const map = weights(type, data, initializers);
  780. for (const [name, value] of Object.entries(data)) {
  781. if (!map[name]) {
  782. attributes[name] = value;
  783. }
  784. }
  785. switch (obj.type) {
  786. case 'loop':
  787. attributes.bodyNetwork = this.network(attributes.bodyNetwork);
  788. attributes.conditionNetwork = this.network(attributes.conditionNetwork);
  789. break;
  790. case 'branch':
  791. attributes.ifBranch = this.network(attributes.ifBranch);
  792. attributes.elseBranch = this.network(attributes.elseBranch);
  793. break;
  794. default:
  795. break;
  796. }
  797. }
  798. const metadata = this.metadata.type(type);
  799. for (let i = 0; i < inputs.length;) {
  800. const input = metadata && metadata.inputs && i < metadata.inputs.length ? metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
  801. const count = input.type === 'Tensor[]' ? inputs.length - i : 1;
  802. const values = inputs.slice(i, i + count);
  803. obj.inputs.push({ name: input.name, visible: true, value: values });
  804. i += count;
  805. }
  806. obj.inputs.push(...initializers);
  807. for (let i = 0; i < outputs.length;) {
  808. const output = metadata && metadata.outputs && i < metadata.outputs.length ? metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
  809. const count = output.type === 'Tensor[]' ? outputs.length - i : 1;
  810. const args = outputs.slice(i, i + count);
  811. obj.outputs.push({ name: output.name, visible: true, value: args });
  812. i += count;
  813. }
  814. this.nodes.push(obj);
  815. return obj;
  816. }
  817. model(model, group, description) {
  818. this.groups |= group.length > 0;
  819. const shortDescription = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  820. switch (model.Type) {
  821. case 'neuralNetworkClassifier': {
  822. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  823. for (const layer of neuralNetworkClassifier.layers) {
  824. const type = layer.layer;
  825. this.node(group, type, layer.name, group === '' ? '' : shortDescription, layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  826. }
  827. this.updateClassifierOutput(group, neuralNetworkClassifier, description);
  828. this.updatePreprocessing(group, neuralNetworkClassifier.preprocessing, description);
  829. return 'Neural Network Classifier';
  830. }
  831. case 'neuralNetwork': {
  832. const neuralNetwork = model.neuralNetwork;
  833. for (const layer of neuralNetwork.layers) {
  834. this.node(group, layer.layer, layer.name, group === '' ? '' : shortDescription, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  835. }
  836. this.updatePreprocessing(group, neuralNetwork.preprocessing, description);
  837. return 'Neural Network';
  838. }
  839. case 'neuralNetworkRegressor': {
  840. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  841. for (const layer of neuralNetworkRegressor.layers) {
  842. this.node(group, layer.layer, layer.name, shortDescription, layer[layer.layer], layer.input, layer.output);
  843. }
  844. this.updatePreprocessing(group, neuralNetworkRegressor, description);
  845. return 'Neural Network Regressor';
  846. }
  847. case 'pipeline': {
  848. for (let i = 0; i < model.pipeline.models.length; i++) {
  849. this.model(model.pipeline.models[i], `${group ? (`${group}/`) : ''}pipeline[${i}]`, description);
  850. }
  851. return 'Pipeline';
  852. }
  853. case 'pipelineClassifier': {
  854. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  855. this.model(model.pipelineClassifier.pipeline.models[i], `${group ? (`${group}/`) : ''}pipelineClassifier[${i}]`, description);
  856. }
  857. return 'Pipeline Classifier';
  858. }
  859. case 'pipelineRegressor': {
  860. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  861. this.model(model.pipelineRegressor.pipeline.models[i], `${group ? (`${group}/`) : ''}pipelineRegressor[${i}]`, description);
  862. }
  863. return 'Pipeline Regressor';
  864. }
  865. case 'glmClassifier': {
  866. this.node(group, 'glmClassifier', null, shortDescription,
  867. {
  868. classEncoding: model.glmClassifier.classEncoding,
  869. offset: model.glmClassifier.offset,
  870. weights: model.glmClassifier.weights
  871. },
  872. [model.description.input[0].name],
  873. [model.description.output[0].name]);
  874. this.updateClassifierOutput(group, model.glmClassifier, description);
  875. return 'Generalized Linear Classifier';
  876. }
  877. case 'glmRegressor': {
  878. this.node(group, 'glmRegressor', null, shortDescription,
  879. model.glmRegressor,
  880. [model.description.input[0].name],
  881. [model.description.output[0].name]);
  882. return 'Generalized Linear Regressor';
  883. }
  884. case 'treeEnsembleClassifier': {
  885. this.node(group, 'treeEnsembleClassifier', null, shortDescription,
  886. model.treeEnsembleClassifier.treeEnsemble,
  887. [model.description.input[0].name],
  888. [model.description.output[0].name]);
  889. this.updateClassifierOutput(group, model.treeEnsembleClassifier, description);
  890. return 'Tree Ensemble Classifier';
  891. }
  892. case 'treeEnsembleRegressor': {
  893. this.node(group, 'treeEnsembleRegressor', null, shortDescription,
  894. model.treeEnsembleRegressor.treeEnsemble,
  895. [model.description.input[0].name],
  896. [model.description.output[0].name]);
  897. return 'Tree Ensemble Regressor';
  898. }
  899. case 'supportVectorClassifier': {
  900. this.node(group, 'supportVectorClassifier', null, shortDescription,
  901. {
  902. coefficients: model.supportVectorClassifier.coefficients,
  903. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  904. kernel: model.supportVectorClassifier.kernel,
  905. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  906. probA: model.supportVectorClassifier.probA,
  907. probB: model.supportVectorClassifier.probB,
  908. rho: model.supportVectorClassifier.rho,
  909. supportVectors: model.supportVectorClassifier.supportVectors
  910. },
  911. [model.description.input[0].name],
  912. [model.description.output[0].name]);
  913. this.updateClassifierOutput(group, model.supportVectorClassifier, description);
  914. return 'Support Vector Classifier';
  915. }
  916. case 'supportVectorRegressor': {
  917. this.node(group, 'supportVectorRegressor', null, shortDescription,
  918. {
  919. coefficients: model.supportVectorRegressor.coefficients,
  920. kernel: model.supportVectorRegressor.kernel,
  921. rho: model.supportVectorRegressor.rho,
  922. supportVectors: model.supportVectorRegressor.supportVectors
  923. },
  924. [model.description.input[0].name],
  925. [model.description.output[0].name]);
  926. return 'Support Vector Regressor';
  927. }
  928. case 'oneHotEncoder': {
  929. const categoryType = model.oneHotEncoder.CategoryType;
  930. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  931. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  932. this.node(group, 'oneHotEncoder', null, shortDescription,
  933. oneHotEncoderParams,
  934. [model.description.input[0].name],
  935. [model.description.output[0].name]);
  936. return 'One Hot Encoder';
  937. }
  938. case 'imputer': {
  939. const imputedValue = model.imputer.ImputedValue;
  940. const replaceValue = model.imputer.ReplaceValue;
  941. const imputerParams = {};
  942. imputerParams[imputedValue] = model.imputer[imputedValue];
  943. imputerParams[replaceValue] = model.imputer[replaceValue];
  944. this.node(group, 'oneHotEncoder', null, shortDescription,
  945. imputerParams,
  946. [model.description.input[0].name],
  947. [model.description.output[0].name]);
  948. return 'Imputer';
  949. }
  950. case 'featureVectorizer': {
  951. this.node(group, 'featureVectorizer', null, shortDescription,
  952. model.featureVectorizer,
  953. model.description.input.map((item) => item.name),
  954. [model.description.output[0].name]);
  955. return 'Feature Vectorizer';
  956. }
  957. case 'dictVectorizer': {
  958. this.node(group, 'dictVectorizer', null, shortDescription,
  959. model.dictVectorizer,
  960. [model.description.input[0].name],
  961. [model.description.output[0].name]);
  962. return 'Dictionary Vectorizer';
  963. }
  964. case 'scaler': {
  965. this.node(group, 'scaler', null, shortDescription,
  966. model.scaler,
  967. [model.description.input[0].name],
  968. [model.description.output[0].name]);
  969. return 'Scaler';
  970. }
  971. case 'categoricalMapping': {
  972. this.node(group, 'categoricalMapping', null, shortDescription,
  973. model.categoricalMapping,
  974. [model.description.input[0].name],
  975. [model.description.output[0].name]);
  976. return 'Categorical Mapping';
  977. }
  978. case 'normalizer': {
  979. this.node(group, 'normalizer', null, shortDescription,
  980. model.normalizer,
  981. [model.description.input[0].name],
  982. [model.description.output[0].name]);
  983. return 'Normalizer';
  984. }
  985. case 'arrayFeatureExtractor': {
  986. this.node(group, 'arrayFeatureExtractor', null, shortDescription,
  987. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  988. [model.description.input[0].name],
  989. [model.description.output[0].name]);
  990. return 'Array Feature Extractor';
  991. }
  992. case 'nonMaximumSuppression': {
  993. const nonMaximumSuppressionParams = {
  994. pickTop: model.nonMaximumSuppression.pickTop,
  995. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  996. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  997. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  998. };
  999. this.node(group, 'nonMaximumSuppression', null, shortDescription,
  1000. nonMaximumSuppressionParams,
  1001. [
  1002. model.nonMaximumSuppression.confidenceInputFeatureName,
  1003. model.nonMaximumSuppression.coordinatesInputFeatureName,
  1004. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  1005. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  1006. ],
  1007. [
  1008. model.nonMaximumSuppression.confidenceOutputFeatureName,
  1009. model.nonMaximumSuppression.coordinatesOutputFeatureName
  1010. ]);
  1011. return 'Non Maximum Suppression';
  1012. }
  1013. case 'wordTagger': {
  1014. this.node(group, 'wordTagger', null, shortDescription,
  1015. model.wordTagger,
  1016. [model.description.input[0].name],
  1017. [
  1018. model.wordTagger.tokensOutputFeatureName,
  1019. model.wordTagger.tokenTagsOutputFeatureName,
  1020. model.wordTagger.tokenLocationsOutputFeatureName,
  1021. model.wordTagger.tokenLengthsOutputFeatureName
  1022. ]);
  1023. return 'Word Tagger';
  1024. }
  1025. case 'textClassifier': {
  1026. this.node(group, 'textClassifier', null, shortDescription,
  1027. model.textClassifier,
  1028. [model.description.input[0].name],
  1029. [model.description.output[0].name]);
  1030. return 'Text Classifier';
  1031. }
  1032. case 'visionFeaturePrint': {
  1033. const visionFeaturePrintParams = {
  1034. scene: model.visionFeaturePrint.scene
  1035. };
  1036. this.node(group, 'visionFeaturePrint', null, shortDescription,
  1037. visionFeaturePrintParams,
  1038. [model.description.input[0].name],
  1039. [model.description.output[0].name]);
  1040. return 'Vision Feature Print';
  1041. }
  1042. case 'soundAnalysisPreprocessing': {
  1043. this.node(group, 'soundAnalysisPreprocessing', null, shortDescription,
  1044. model.soundAnalysisPreprocessing,
  1045. [model.description.input[0].name],
  1046. [model.description.output[0].name]);
  1047. return 'Sound Analysis Preprocessing';
  1048. }
  1049. case 'kNearestNeighborsClassifier': {
  1050. this.node(group, 'kNearestNeighborsClassifier', null, shortDescription,
  1051. model.kNearestNeighborsClassifier,
  1052. [model.description.input[0].name],
  1053. [model.description.output[0].name]);
  1054. this.updateClassifierOutput(group, model.kNearestNeighborsClassifier, description);
  1055. return 'Nearest Neighbors Classifier';
  1056. }
  1057. case 'itemSimilarityRecommender': {
  1058. this.node(group, 'itemSimilarityRecommender', null, shortDescription,
  1059. {
  1060. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  1061. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  1062. },
  1063. model.description.input.map((feature) => feature.name),
  1064. model.description.output.map((feature) => feature.name));
  1065. return 'Item Similarity Recommender';
  1066. }
  1067. case 'audioFeaturePrint': {
  1068. this.node(group, 'audioFeaturePrint', null, shortDescription,
  1069. model.audioFeaturePrint,
  1070. [model.description.input[0].name],
  1071. [model.description.output[0].name]);
  1072. return 'Audio Feature Print';
  1073. }
  1074. case 'linkedModel': {
  1075. this.node(group, 'linkedModel', null, shortDescription,
  1076. model.linkedModel.linkedModelFile,
  1077. [model.description.input[0].name],
  1078. [model.description.output[0].name]);
  1079. return 'Linked Model';
  1080. }
  1081. case 'customModel': {
  1082. this.node(group, 'customModel', null, shortDescription,
  1083. { className: model.customModel.className, parameters: model.customModel.parameters },
  1084. [model.description.input[0].name],
  1085. [model.description.output[0].name]);
  1086. return 'customModel';
  1087. }
  1088. case 'mlProgram': {
  1089. return this.program(model.mlProgram, group);
  1090. }
  1091. default: {
  1092. throw new coreml.Error(`Unsupported model type '${JSON.stringify(Object.keys(model))}'.`);
  1093. }
  1094. }
  1095. }
  1096. updateClassifierOutput(group, classifier, description) {
  1097. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  1098. if (!labelProbabilityLayerName && this.nodes.length > 0) {
  1099. const node = this.nodes.slice(-1).pop();
  1100. if (node && node.outputs.length === 1 && node.outputs[0].value.length === 1) {
  1101. labelProbabilityLayerName = node.outputs[0].value[0].name;
  1102. }
  1103. }
  1104. let predictedFeatureName = description.predictedFeatureName;
  1105. let predictedProbabilitiesName = description.predictedProbabilitiesName;
  1106. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  1107. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  1108. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  1109. const labelProbabilityInput = `${labelProbabilityLayerName}:labelProbabilityLayerName`;
  1110. const values = new Set();
  1111. for (const node of this.nodes) {
  1112. for (const output of node.outputs) {
  1113. for (const value of output.value) {
  1114. if (value.name === labelProbabilityLayerName) {
  1115. value.name = labelProbabilityInput;
  1116. values.add(value);
  1117. }
  1118. }
  1119. }
  1120. }
  1121. this.values.set(labelProbabilityInput, this.values.get(labelProbabilityLayerName));
  1122. this.values.delete(labelProbabilityLayerName);
  1123. const type = classifier.ClassLabels;
  1124. const node = {
  1125. // group: this._group,
  1126. type,
  1127. name: null,
  1128. description: '',
  1129. attributes: classifier[type] || {}
  1130. };
  1131. node.inputs = [
  1132. { name: 'input', visible: true, value: Array.from(values) }
  1133. ];
  1134. node.outputs = [
  1135. { name: 'probabilities', visible: true, value: [this.output(predictedProbabilitiesName)] },
  1136. { name: 'feature', visible: true, value: [this.output(predictedFeatureName)] }
  1137. ];
  1138. this.nodes.push(node);
  1139. }
  1140. }
  1141. updatePreprocessing(group, preprocessings, description) {
  1142. if (preprocessings && preprocessings.length > 0) {
  1143. const preprocessingInput = description.input[0].name;
  1144. const inputNodes = [];
  1145. for (const node of this.nodes) {
  1146. if (node.inputs.some((input) => Array.isArray(input.value) && input.value.some((arg) => arg.name === preprocessingInput))) {
  1147. inputNodes.push(node);
  1148. }
  1149. }
  1150. let currentOutput = preprocessingInput;
  1151. let preprocessorOutput = null;
  1152. let preprocessorIndex = 0;
  1153. for (const preprocessing of preprocessings) {
  1154. const input = preprocessing.featureName ? preprocessing.featureName : currentOutput;
  1155. currentOutput = `${preprocessingInput}:${preprocessorIndex}`;
  1156. const preprocessor = preprocessing.preprocessor;
  1157. const node = this.node(group, preprocessor, null, '', preprocessing[preprocessor], [input], [currentOutput]);
  1158. /* eslint-disable prefer-destructuring */
  1159. preprocessorOutput = node.outputs[0].value[0];
  1160. /* eslint-enable prefer-destructuring */
  1161. preprocessorIndex++;
  1162. }
  1163. for (const node of inputNodes) {
  1164. for (const input of node.inputs) {
  1165. if (Array.isArray(input.value)) {
  1166. for (let i = 0; i < input.value.length; i++) {
  1167. if (input.value[i].name === preprocessingInput) {
  1168. input.value[i] = preprocessorOutput;
  1169. }
  1170. }
  1171. }
  1172. }
  1173. }
  1174. }
  1175. }
  1176. program(program, group) {
  1177. // need to handle functions other than main?
  1178. const name = this.name || 'main';
  1179. const main = program.functions[name];
  1180. // need to handle more than one block specialization?
  1181. const block_specializations = main.block_specializations;
  1182. const key = Object.keys(block_specializations).filter((key) => key.startsWith('CoreML')).shift();
  1183. const block = block_specializations[key];
  1184. const convertValue = (value) => {
  1185. switch (value.value) {
  1186. case 'immediateValue': {
  1187. const tensor = value.immediateValue.tensor;
  1188. const type = coreml.Utility.valueType(value.type);
  1189. let values = null;
  1190. switch (tensor.value) {
  1191. case 'ints':
  1192. values = tensor.ints.values;
  1193. break;
  1194. case 'strings':
  1195. values = tensor.strings.values;
  1196. break;
  1197. case 'bools':
  1198. values = tensor.bools.values;
  1199. break;
  1200. case 'floats':
  1201. values = tensor.floats.values;
  1202. break;
  1203. case 'bytes':
  1204. values = tensor.bytes.values;
  1205. break;
  1206. default:
  1207. throw new coreml.Error(`Unsupported tensor value '${tensor.value}'.`);
  1208. }
  1209. if (type.shape.dimensions.length === 0) {
  1210. [values] = values;
  1211. }
  1212. return values;
  1213. }
  1214. case 'blobFileValue': {
  1215. const type = coreml.Utility.valueType(value.type);
  1216. const blob = value.blobFileValue;
  1217. const offset = Number(blob.offset);
  1218. const file = blob.fileName;
  1219. let data = null;
  1220. const stream = this.weights.get(file);
  1221. if (stream) {
  1222. stream.seek(offset);
  1223. const buffer = stream.read(32);
  1224. const reader = base.BinaryReader.open(buffer);
  1225. const signature = reader.uint32();
  1226. if (signature === 0xdeadbeef) {
  1227. reader.uint32(); // dataType
  1228. const size = reader.uint64().toNumber();
  1229. const offset = reader.uint64().toNumber();
  1230. stream.seek(offset);
  1231. const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
  1232. switch (type.dataType) {
  1233. case 'float32': {
  1234. const buffer = stream.read(size);
  1235. data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
  1236. break;
  1237. }
  1238. case 'float16':
  1239. case 'int1': case 'int2': case 'int3': case 'int4': case 'int6': case 'int8': case 'int32':
  1240. case 'uint1': case 'uint2': case 'uint3': case 'uint4': case 'uint6': case 'uint8': case 'uint16': {
  1241. data = stream.read(size);
  1242. break;
  1243. }
  1244. default:
  1245. throw new coreml.Error(`Unsupported blob data type '${type.dataType}'.`);
  1246. }
  1247. }
  1248. }
  1249. return new coreml.Tensor(type, data, null, 'Blob');
  1250. }
  1251. default: {
  1252. throw new coreml.Error(`Unsupported value '${value.value}'.`);
  1253. }
  1254. }
  1255. };
  1256. const operations = block.operations.map((op) => {
  1257. const operation = {
  1258. type: op.type,
  1259. attributes: {}
  1260. };
  1261. for (const [key, value] of Object.entries(op.attributes)) {
  1262. operation.attributes[key] = convertValue(value);
  1263. }
  1264. operation.inputs = Object.entries(op.inputs).map(([name, input]) => {
  1265. const value = input.arguments.map((argument) => {
  1266. if (argument.value && argument.value.value && argument.value.blobFileValue) {
  1267. return { name: '', value: convertValue(argument.value) };
  1268. }
  1269. if (argument.name) {
  1270. const value = this.input(argument.name);
  1271. value.to.push(operation);
  1272. return value;
  1273. }
  1274. return { value: argument.value };
  1275. });
  1276. return { name, value };
  1277. });
  1278. operation.outputs = op.outputs.map((output) => {
  1279. const value = this.input(output.name);
  1280. value.type = coreml.Utility.valueType(output.type);
  1281. value.from.push(operation);
  1282. return { name: 'output', value: [value] };
  1283. });
  1284. return operation;
  1285. });
  1286. for (const op of operations) {
  1287. if (op.type === 'const' && op.inputs.length === 0 &&
  1288. op.outputs.length === 1 && op.outputs[0].value.length === 1) {
  1289. /* eslint-disable prefer-destructuring */
  1290. const value = op.outputs[0].value[0];
  1291. /* eslint-enable prefer-destructuring */
  1292. if (op.attributes && op.attributes.val) {
  1293. const type = value.type;
  1294. const data = op.attributes.val;
  1295. if (data instanceof Uint8Array && data.length === 2 &&
  1296. type.dataType === 'float16' && type.shape.dimensions.length === 0) {
  1297. const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  1298. value.value = view.getFloat16(0, true);
  1299. } else {
  1300. value.value = data;
  1301. }
  1302. value.const = true;
  1303. op.delete = true;
  1304. }
  1305. }
  1306. }
  1307. for (const op of operations) {
  1308. for (const input of op.inputs) {
  1309. if (input.value.length > 1 && input.value.some((argument) => argument.const)) {
  1310. if (!input.value.every((argument) => argument.value instanceof coreml.Tensor)) {
  1311. for (const value of input.value) {
  1312. for (const from of value.from) {
  1313. from.delete = false;
  1314. }
  1315. delete value.value;
  1316. }
  1317. }
  1318. }
  1319. }
  1320. }
  1321. for (const op of operations.filter((op) => !op.delete)) {
  1322. op.inputs = op.inputs.filter((input) => {
  1323. if (input.value.every((value) => value.value === undefined || value.value instanceof coreml.Tensor)) {
  1324. return true;
  1325. }
  1326. op.attributes[input.name] = input.value.length === 1 ?
  1327. input.value[0].value :
  1328. input.value.map((argument) => argument.value[0]);
  1329. return false;
  1330. });
  1331. }
  1332. const mapValue = (name, value) => {
  1333. if (value.value instanceof coreml.Tensor) {
  1334. value.initializer = value.value;
  1335. delete value.value;
  1336. if (name === '') {
  1337. this.values.set(value, value);
  1338. return value;
  1339. }
  1340. }
  1341. if (!this.values.has(name)) {
  1342. this.values.set(name, value);
  1343. } else if ((value.type && !value.type.equals(this.values.get(name).type)) ||
  1344. (value.initializer && value.initializer !== this.values.get(name).initializer)) {
  1345. throw new coreml.Error(`Duplicate value '${name}'.`);
  1346. }
  1347. return this.values.get(name);
  1348. };
  1349. for (const op of operations.filter((op) => !op.delete)) {
  1350. for (const argument of op.inputs) {
  1351. for (const value of argument.value) {
  1352. mapValue(value.name, value);
  1353. }
  1354. }
  1355. for (const argument of op.outputs) {
  1356. for (const value of argument.value) {
  1357. mapValue(value.name, value);
  1358. }
  1359. }
  1360. }
  1361. for (const op of operations.filter((op) => !op.delete)) {
  1362. op.group = group;
  1363. op.type = `program:${op.type}`;
  1364. const metadata = this.metadata.type(op.type);
  1365. if (metadata && Array.isArray(metadata.inputs)) {
  1366. const map = new Map(metadata.inputs.map((input, index) => [input.name, index + 1]));
  1367. op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
  1368. }
  1369. this.nodes.push(op);
  1370. }
  1371. return 'ML Program';
  1372. }
  1373. };
  /**
   * Static helpers that translate Core ML / MIL protobuf type messages into the
   * coreml type objects used by the rest of this module.
   */
  coreml.Utility = class {

      /**
       * Reverse-looks-up an enum value in `coreml.proto`.
       *
       * @param {string} name - Dotted path of the enum below `coreml.proto`.
       * @param {*} value - Enum value to translate.
       * @returns {*} The enum key for `value`, or `value` itself when the path does
       *     not resolve or the enum has no such value.
       */
      static enum(name, value) {
          let scope = coreml.proto;
          for (const part of name.split('.')) {
              if (!scope) {
                  break;
              }
              scope = scope[part];
          }
          if (!scope) {
              return value;
          }
          // Reverse lookup tables are built once per enum path and cached.
          if (!coreml.Utility._enumKeyMap) {
              coreml.Utility._enumKeyMap = new Map();
          }
          const cache = coreml.Utility._enumKeyMap;
          if (!cache.has(name)) {
              const reverse = new Map();
              for (const [key, code] of Object.entries(scope)) {
                  reverse.set(code, key);
              }
              cache.set(name, reverse);
          }
          const lookup = cache.get(name);
          return lookup.has(value) ? lookup.get(value) : value;
      }

      /**
       * Converts a Core ML feature type message into a coreml type object.
       *
       * @param {object} type - Feature type message; dispatch is on `type.Type`.
       * @returns {object|string} The converted type, wrapped in `coreml.OptionalType`
       *     when `type.isOptional` is set, or '?' when `type` is falsy.
       * @throws {coreml.Error} On an unsupported feature type or array data type.
       */
      static featureType(type) {
          if (!type) {
              return '?';
          }
          let result = null;
          switch (type.Type) {
              // 'arrayType' is reached through 'stateType', whose message wraps an array type.
              case 'arrayType':
              case 'multiArrayType': {
                  const arrayType = type[type.Type];
                  const hasShape = arrayType.shape && arrayType.shape.length > 0;
                  const shape = new coreml.TensorShape(hasShape ? arrayType.shape.map((dim) => Number(dim)) : []);
                  const ArrayDataType = coreml.proto.ArrayFeatureType.ArrayDataType;
                  const dataTypeNames = [
                      [ArrayDataType.INVALID_ARRAY_DATA_TYPE, '?'],
                      [ArrayDataType.FLOAT16, 'float16'],
                      [ArrayDataType.FLOAT32, 'float32'],
                      [ArrayDataType.DOUBLE, 'float64'],
                      [ArrayDataType.INT32, 'int32']
                  ];
                  const entry = dataTypeNames.find(([code]) => code === arrayType.dataType);
                  if (!entry) {
                      throw new coreml.Error(`Unsupported array data type '${arrayType.dataType}'.`);
                  }
                  result = new coreml.TensorType(entry[1], shape);
                  break;
              }
              case 'stringType':
                  result = new coreml.TensorType('string');
                  break;
              case 'doubleType':
                  result = new coreml.TensorType('float64');
                  break;
              case 'int64Type':
                  result = new coreml.TensorType('int64');
                  break;
              case 'dictionaryType':
                  // e.g. 'int64KeyType' -> 'int64'; values are always reported as float64.
                  result = new coreml.MapType(type.dictionaryType.KeyType.replace('KeyType', ''), 'float64');
                  break;
              case 'sequenceType':
                  result = new coreml.SequenceType(coreml.Utility.featureType(type.sequenceType));
                  break;
              case 'imageType': {
                  const image = type.imageType;
                  result = new coreml.ImageType(image.colorSpace, image.width, image.height);
                  break;
              }
              case 'stateType':
                  result = new coreml.StateType(coreml.Utility.featureType(type.stateType));
                  break;
              default:
                  throw new coreml.Error(`Unsupported feature type '${type.Type}'.`);
          }
          return type.isOptional ? new coreml.OptionalType(result) : result;
      }

      /**
       * Converts a MIL tensor type into a `coreml.TensorType`.
       *
       * Dimensions without a constant size become '?'. Data type names are the
       * lower-cased `MILSpec.DataType` keys. Code 1 is named 'boolean', and code 0
       * is treated as unsupported.
       *
       * @param {object} type - MIL tensor type (`dataType`, `dimensions`).
       * @returns {coreml.TensorType} The converted tensor type.
       * @throws {coreml.Error} On an unsupported data type.
       */
      static tensorType(type) {
          if (!coreml.Utility._dataTypes) {
              const names = new Map();
              for (const [key, code] of Object.entries(coreml.proto.MILSpec.DataType)) {
                  names.set(code, key.toLowerCase());
              }
              names.delete(0);
              names.set(1, 'boolean');
              coreml.Utility._dataTypes = names;
          }
          const dimensions = type.dimensions.map((dimension) => (dimension.constant ? dimension.constant.size : '?'));
          const dataType = coreml.Utility._dataTypes.get(type.dataType);
          if (!dataType) {
              throw new coreml.Error(`Unsupported data type '${type.dataType}'.`);
          }
          return new coreml.TensorType(dataType, new coreml.TensorShape(dimensions));
      }

      /**
       * Converts a MIL value type (tensor, list or dictionary) into a coreml type.
       *
       * @param {object} type - MIL value type; dispatch is on `type.type`.
       * @returns {object} The converted type.
       * @throws {coreml.Error} On an unsupported value type.
       */
      static valueType(type) {
          const kind = type.type;
          if (kind === 'tensorType') {
              return coreml.Utility.tensorType(type.tensorType);
          }
          if (kind === 'listType') {
              return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
          }
          if (kind === 'dictionaryType') {
              const keyType = coreml.Utility.valueType(type.dictionaryType.keyType);
              const itemType = coreml.Utility.valueType(type.dictionaryType.valueType);
              return new coreml.MapType(keyType, itemType);
          }
          throw new coreml.Error(`Unsupported value type '${type.type}'.`);
      }
  };
  1493. coreml.Error = class extends Error {
  1494. constructor(message) {
  1495. super(message);
  1496. this.name = 'Error loading Core ML model.';
  1497. }
  1498. };
  1499. export const ModelFactory = coreml.ModelFactory;