coreml.js 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569
  1. import * as base from './base.js';
  const coreml = {};

  /**
   * Detects Core ML artifacts and opens them as coreml.Model instances.
   * Recognized inputs: binary and text protobuf models, package manifests (Manifest.json),
   * FeatureDescriptions.json / Metadata.json side files, MIL programs and weight blobs.
   */
  coreml.ModelFactory = class {

      /**
       * Classifies `context` as one of the Core ML artifact types.
       * @param {object} context - file context exposing identifier, stream, tags(), peek(), read() and set().
       * @returns {Promise<*>} the value of `context.set(type)` for a recognized artifact, otherwise null.
       */
      async match(context) {
          const stream = context.stream;
          const identifier = context.identifier.toLowerCase();
          const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
          const tags = await context.tags('pb');
          // Field 1 as varint and field 2 as length-delimited; presumably Model.specificationVersion
          // and Model.description (the pbtxt check below looks for the same two fields by name).
          if (tags.get(1) === 0 && tags.get(2) === 2) {
              // Field numbers of the model-type payloads (assumed to mirror the Model.Type oneof).
              const isModelTypeField = (key) =>
                  (key >= 200 && key < 220) || (key >= 300 && key < 320) || (key >= 400 && key < 420) ||
                  (key >= 500 && key < 520) || (key >= 550 && key < 560) || (key >= 600 && key < 620) ||
                  (key === 900) ||
                  (key >= 2000 && key < 2010) || (key === 3000);
              // A generic '.pb' file is only claimed when it carries at least one model-type field.
              if (extension === 'pb' && !Array.from(tags.keys()).some((key) => isModelTypeField(key))) {
                  return null;
              }
              return context.set('coreml.pb');
          }
          if (extension === 'pbtxt') {
              const textTags = await context.tags('pbtxt');
              if (textTags.has('specificationVersion') && textTags.has('description')) {
                  return context.set('coreml.pbtxt');
              }
          }
          if (identifier === 'manifest.json') {
              const obj = await context.peek('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  // Same rule open() enforces for manifests: exactly one '.mlmodel' entry.
                  const models = Object.values(obj.itemInfoEntries).filter((entry) =>
                      entry && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel'));
                  if (models.length === 1) {
                      return context.set('coreml.manifest');
                  }
              }
          }
          if (identifier === 'model.mil') {
              try {
                  const reader = await context.read('text', 2048);
                  const signature = reader.read('\n');
                  if (signature && signature.trim().startsWith('program')) {
                      return context.set('coreml.mil');
                  }
              } catch {
                  // Best-effort probe: a file that cannot be read as text is simply not matched here.
              }
          }
          if (identifier === 'featuredescriptions.json') {
              const obj = await context.peek('json');
              if (obj && (obj.Inputs || obj.Outputs)) {
                  return context.set('coreml.featuredescriptions');
              }
          }
          if (identifier === 'metadata.json') {
              const obj = await context.peek('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  return context.set('coreml.metadata');
              }
              if (Array.isArray(obj) && obj.some((item) => item && item.metadataOutputVersion && item.specificationVersion)) {
                  return context.set('coreml.metadata.mlmodelc');
              }
          }
          if (extension === 'bin' && stream.length > 16) {
              const buffer = stream.peek(Math.min(256, stream.length));
              // Look for the 32-bit value 0xdeadbeef (read little-endian) at any byte offset of the
              // peeked prefix, including the final 4-byte window.
              for (let i = 0; i + 4 <= buffer.length; i++) {
                  const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer[i + 3] << 24) >>> 0;
                  if (signature === 0xdeadbeef) {
                      return context.set('coreml.weights');
                  }
              }
          }
          return null;
      }

      /**
       * Returns false to drop a 'coreml.mil' match when the context was classified as
       * compiled-model metadata ('coreml.metadata.mlmodelc'); keeps every other match.
       */
      filter(context, match) {
          return !(context.type === 'coreml.metadata.mlmodelc' && match.type === 'coreml.mil');
      }

      /**
       * Opens the artifact previously classified by match().
       * @param {object} context - the matched file context (its `type` was set by match()).
       * @returns {Promise<coreml.Model>}
       * @throws {coreml.Error} when the model cannot be decoded or the artifact type is unsupported.
       */
      async open(context) {
          coreml.proto = await context.require('./coreml-proto');
          coreml.proto = coreml.proto.CoreML.Specification;
          const metadata = await context.metadata('coreml-metadata.json');
          const isModelEntry = (entry) => entry && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel');
          const openBinary = async (content, context, path, format) => {
              let model = null;
              try {
                  const reader = await content.read('protobuf.binary');
                  model = coreml.proto.Model.decode(reader);
              } catch (error) {
                  const message = error && error.message ? error.message : error.toString();
                  throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`);
              }
              // Collect the blob file names referenced by ML program operation attributes,
              // descending into pipeline sub-models.
              const weightPaths = new Set();
              const walkProgram = (program) => {
                  for (const func of Object.values(program.functions)) {
                      for (const block of Object.values(func.block_specializations)) {
                          for (const operation of block.operations) {
                              for (const value of Object.values(operation.attributes)) {
                                  if (value.blobFileValue && value.blobFileValue.fileName) {
                                      weightPaths.add(value.blobFileValue.fileName);
                                  }
                              }
                          }
                      }
                  }
              };
              const walkModel = (model) => {
                  if (model.mlProgram) {
                      walkProgram(model.mlProgram);
                  }
                  const pipelines = [
                      model.pipeline,
                      model.pipelineClassifier && model.pipelineClassifier.pipeline,
                      model.pipelineRegressor && model.pipelineRegressor.pipeline
                  ];
                  for (const pipeline of pipelines) {
                      if (pipeline && pipeline.models) {
                          for (const node of pipeline.models) {
                              walkModel(node);
                          }
                      }
                  }
              };
              walkModel(model);
              const weights = new Map();
              if (weightPaths.size > 0) {
                  // '@model_path/...' names are resolved against the folder containing `path`.
                  const folder = path.replace(/\/[^/]*$/, '');
                  const keys = Array.from(weightPaths);
                  const locations = keys.map((key) => key.replace(/^@model_path\//, `${folder}/`));
                  // Weights are optional for showing the graph: keep every file that loads and
                  // ignore the ones that fail instead of discarding all of them.
                  const results = await Promise.allSettled(locations.map(async (location) => context.fetch(location)));
                  results.forEach((result, index) => {
                      if (result.status === 'fulfilled' && result.value) {
                          weights.set(keys[index], result.value.stream);
                      }
                  });
              }
              const name = `${format || 'Core ML'} v${model.specificationVersion}`;
              return new coreml.Model(new coreml.Context(metadata, name, model, weights));
          };
          const openText = async (context) => {
              let model = null;
              try {
                  const reader = await context.read('protobuf.text');
                  model = coreml.proto.Model.decodeText(reader);
              } catch (error) {
                  const message = error && error.message ? error.message : error.toString();
                  throw new coreml.Error(`File format is not coreml.Model (${message.replace(/\.$/, '')}).`);
              }
              const format = `Core ML v${model.specificationVersion}`;
              return new coreml.Model(new coreml.Context(metadata, format, model));
          };
          const openManifest = async (obj, context, path) => {
              const entries = Object.values((obj && obj.itemInfoEntries) || {}).filter(isModelEntry);
              if (entries.length !== 1) {
                  throw new coreml.Error('Manifest does not contain Core ML model.');
              }
              const name = `${path}Data/${entries[0].path}`;
              const content = await context.fetch(name);
              return openBinary(content, context, name, 'Core ML Package');
          };
          const openManifestStream = async (context, path) => {
              const name = `${path}Manifest.json`;
              const content = await context.fetch(name);
              const obj = await content.read('json');
              return openManifest(obj, context, path);
          };
          switch (context.type) {
              case 'coreml.pb': {
                  return openBinary(context, context, '');
              }
              case 'coreml.pbtxt': {
                  return openText(context);
              }
              case 'coreml.manifest': {
                  const obj = await context.peek('json');
                  return openManifest(obj, context, '');
              }
              case 'coreml.featuredescriptions':
              case 'coreml.metadata': {
                  return openManifestStream(context, '../../');
              }
              case 'coreml.metadata.mlmodelc': {
                  throw new coreml.Error('Core ML Model Archive format is not supported.');
              }
              case 'coreml.mil': {
                  throw new coreml.Error('Core ML MIL format is not supported.');
              }
              case 'coreml.weights': {
                  return openManifestStream(context, '../../../');
              }
              default: {
                  throw new coreml.Error(`Unsupported Core ML format '${context.type}'.`);
              }
          }
      }
  };
  /**
   * Top-level model view built from a coreml.Context: the format string, a copy of the
   * metadata arguments, the main graph(s) as `modules`, and the named `functions`.
   * `version` and `description` are only set when the context provides them.
   */
  coreml.Model = class {

      constructor(context) {
          const toGraph = (graphContext) => new coreml.Graph(graphContext);
          this.format = context.format;
          this.metadata = Array.from(context.metadata);
          this.modules = context.graphs.map(toGraph);
          this.functions = context.functions.map(toGraph);
          const { version, description } = context;
          if (version) {
              this.version = version;
          }
          if (description) {
              this.description = description;
          }
      }
  };
  /**
   * Graph view built from a coreml.Context.Graph. Wraps every raw value record in a
   * coreml.Value (once, even if the record is reachable under several keys), turns
   * nested loop/branch sub-networks into coreml.Graph instances, and builds the nodes.
   */
  coreml.Graph = class {

      constructor(context) {
          this.name = context.name || '';
          this.type = context.type || '';
          this.description = context.description;
          this.groups = context.groups;
          for (const record of context.values.values()) {
              if (record.value) {
                  continue;
              }
              record.value = new coreml.Value(record.name, record.type, record.description, record.initializer);
          }
          const toArgument = (argument) => new coreml.Argument(
              argument.name,
              argument.value.map((record) => record.value),
              null,
              argument.visible);
          this.inputs = context.inputs.map(toArgument);
          this.outputs = context.outputs.map(toArgument);
          // Control-flow layers carry sub-network contexts that must become graphs themselves.
          for (const node of context.nodes) {
              const attributes = node.attributes;
              if (node.type === 'loop') {
                  attributes.conditionNetwork = new coreml.Graph(attributes.conditionNetwork);
                  attributes.bodyNetwork = new coreml.Graph(attributes.bodyNetwork);
              } else if (node.type === 'branch') {
                  attributes.ifBranch = new coreml.Graph(attributes.ifBranch);
                  attributes.elseBranch = new coreml.Graph(attributes.elseBranch);
              }
          }
          this.nodes = context.nodes.map((node) => new coreml.Node(context, node));
      }
  };
  /**
   * Named argument: a value (or list of values) with an optional type tag and a
   * visibility flag (defaults: type null, visible true).
   */
  coreml.Argument = class {

      constructor(name, value, type = null, visible = true) {
          Object.assign(this, { name, value, type, visible });
      }
  };
  /**
   * A named graph value. Without an explicit `type`, the initializer's type is used;
   * `quantization` is copied from the initializer, or null when there is none.
   * @throws {coreml.Error} if `name` is not a string.
   */
  coreml.Value = class {

      constructor(name, type, description = null, initializer = null) {
          const isIdentifier = typeof name === 'string';
          if (!isIdentifier) {
              throw new coreml.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          let resolvedType = type;
          if (!type && initializer) {
              resolvedType = initializer.type;
          }
          this.type = resolvedType;
          this.description = description;
          this.initializer = initializer;
          if (initializer) {
              this.quantization = initializer.quantization;
          } else {
              this.quantization = null;
          }
      }
  };
  /**
   * Graph node built from a raw node record produced by coreml.Context.Graph.
   * Resolves the node type against the metadata, wraps inputs/outputs as coreml.Argument
   * lists, and converts attributes. An attribute is hidden when its metadata marks it
   * invisible or when its value equals the metadata default.
   */
  coreml.Node = class {

      /**
       * @param {object} context - graph context; `context.metadata` provides type() and attribute() lookups.
       * @param {object} obj - raw node record (type, name, group, description, inputs, outputs, attributes, optional chain).
       * @throws {coreml.Error} if `obj.type` is missing.
       */
      constructor(context, obj) {
          if (!obj.type) {
              throw new coreml.Error('Undefined node type.');
          }
          if (obj.group) {
              this.group = obj.group;
          }
          const type = context.metadata.type(obj.type);
          this.type = type ? { ...type } : { name: obj.type };
          // Types may be namespaced with ':'; display only the last segment.
          this.type.name = obj.type.split(':').pop();
          this.name = obj.name || '';
          this.description = obj.description || '';
          const toArgument = (argument) => {
              const values = argument.value.map((value) => value.value);
              return new coreml.Argument(argument.name, values, null, argument.visible);
          };
          this.inputs = (obj.inputs || []).map(toArgument);
          this.outputs = (obj.outputs || []).map(toArgument);
          this.attributes = Object.entries(obj.attributes || {}).map(([name, value]) => {
              const metadata = context.metadata.attribute(obj.type, name);
              let type = null;
              let visible = true;
              if (value instanceof coreml.Tensor) {
                  type = 'tensor';
              } else if (value instanceof coreml.Graph) {
                  type = 'graph';
              }
              if (metadata) {
                  type = metadata.type ? metadata.type : type;
                  if (type && coreml.proto) {
                      // Presumably maps raw enum numbers to names (coreml.Utility.enum is defined elsewhere).
                      value = coreml.Utility.enum(type, value);
                  }
                  if (metadata.visible === false) {
                      visible = false;
                  } else if (metadata.default !== undefined) {
                      // int64 fields may decode as BigInt, which JSON.stringify() cannot serialize.
                      // Convert only BigInt items so strings/booleans keep their displayed value.
                      if (Array.isArray(value)) {
                          value = value.map((item) => typeof item === 'bigint' ? Number(item) : item);
                      }
                      if (typeof value === 'bigint') {
                          value = Number(value);
                      }
                      if (JSON.stringify(metadata.default) === JSON.stringify(value)) {
                          visible = false;
                      }
                  }
              }
              return new coreml.Argument(name, value, type, visible);
          });
          if (Array.isArray(obj.chain)) {
              this.chain = obj.chain.map((item) => new coreml.Node(context, item));
          }
      }
  };
  /**
   * Constant tensor (e.g. layer weights).
   * `encoding` is '|' for 'float32'; '>' for 'uint' or 'int' followed by exactly one
   * character; '<' otherwise. `quantization` is only set when linear (scale + bias arrays)
   * or lookup-table (non-empty floatValue) parameters are present. If both are present,
   * lookup-table wins.
   */
  coreml.Tensor = class {

      constructor(type, values, quantization, category) {
          this.type = type;
          this.values = values;
          this.category = category;
          const { dataType } = type;
          if (dataType === 'float32') {
              this.encoding = '|';
          } else {
              const unsignedShort = dataType.startsWith('uint') && dataType.length === 5;
              const signedShort = dataType.startsWith('int') && dataType.length === 4;
              this.encoding = unsignedShort || signedShort ? '>' : '<';
          }
          const linear = quantization && quantization.linearQuantization;
          if (linear && Array.isArray(linear.scale) && Array.isArray(linear.bias)) {
              this.quantization = { type: 'linear', scale: linear.scale, bias: linear.bias };
          }
          const lookup = quantization && quantization.lookupTableQuantization;
          if (lookup && lookup.floatValue && lookup.floatValue.length > 0) {
              this.quantization = { type: 'lookup', value: lookup.floatValue };
          }
      }
  };
  /** Tensor type: element data type name plus a coreml.TensorShape (empty shape by default). */
  coreml.TensorType = class {

      constructor(dataType, shape) {
          this.dataType = dataType;
          this.shape = shape ? shape : new coreml.TensorShape([]);
      }

      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (!this.shape) {
              return this.shape;
          }
          return this.shape.equals(obj.shape);
      }

      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  /**
   * Tensor shape as a plain array of dimensions.
   * BigInt dimensions (protobuf int64) are converted to Number. The shape prints as
   * '[d0,d1,...]', or as '' when it has no dimensions.
   */
  coreml.TensorShape = class {

      /**
       * @param {Iterable<number|bigint>} dimensions - dimension sizes. A null/undefined value yields an empty shape.
       */
      constructor(dimensions) {
          // BigInt has no toNumber() method; Number() is the correct conversion.
          // Array.from also turns typed arrays into a plain Array, so the mapped values are not coerced.
          this.dimensions = Array.from(dimensions || [], (dim) => typeof dim === 'bigint' ? Number(dim) : dim);
      }

      equals(obj) {
          return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) &&
              this.dimensions.length === obj.dimensions.length &&
              obj.dimensions.every((value, index) => this.dimensions[index] === value);
      }

      toString() {
          return Array.isArray(this.dimensions) && this.dimensions.length > 0 ?
              `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]` : '';
      }
  };
  /** List type whose items share one `elementType`; prints as 'list<element>'. */
  coreml.ListType = class {

      constructor(elementType) {
          this.elementType = elementType;
      }

      equals(obj) {
          if (!(obj instanceof coreml.ListType)) {
              return false;
          }
          return this.elementType.equals(obj.elementType);
      }

      toString() {
          const element = `${this.elementType}`;
          return `list<${element}>`;
      }
  };
  /** Dictionary type from `keyType` to `valueType`; prints as 'map<key,value>'. */
  coreml.MapType = class {

      constructor(keyType, valueType) {
          this.keyType = keyType;
          this.valueType = valueType;
      }

      equals(obj) {
          if (!(obj instanceof coreml.MapType)) {
              return false;
          }
          return this.keyType.equals(obj.keyType) && this.valueType.equals(obj.valueType);
      }

      toString() {
          const key = `${this.keyType}`;
          const value = `${this.valueType}`;
          return `map<${key},${value}>`;
      }
  };
  /** Sequence type wrapping an element `type`; prints as 'sequence<type>'. */
  coreml.SequenceType = class {

      constructor(type) {
          this.type = type;
      }

      equals(obj) {
          if (!(obj instanceof coreml.SequenceType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      toString() {
          const inner = `${this.type}`;
          return `sequence<${inner}>`;
      }
  };
  /**
   * Image feature type: color space name plus width and height.
   * @throws {coreml.Error} for a color space value other than GRAYSCALE, RGB, BGR or GRAYSCALE_FLOAT16.
   */
  coreml.ImageType = class {

      constructor(colorSpace, width, height) {
          this.width = width;
          this.height = height;
          const spaces = coreml.proto.ImageFeatureType.ColorSpace;
          // Ordered table; the first entry whose enum value strictly equals `colorSpace` wins.
          const names = [
              [spaces.GRAYSCALE, 'grayscale'],
              [spaces.RGB, 'RGB'],
              [spaces.BGR, 'BGR'],
              [spaces.GRAYSCALE_FLOAT16, 'grayscale:float16']
          ];
          const entry = names.find(([value]) => value === colorSpace);
          if (entry === undefined) {
              throw new coreml.Error(`Unsupported image color space '${colorSpace}'.`);
          }
          this.colorSpace = entry[1];
      }

      equals(obj) {
          if (!(obj instanceof coreml.ImageType)) {
              return false;
          }
          return this.width === obj.width && this.height === obj.height && this.colorSpace === obj.colorSpace;
      }

      toString() {
          const size = `${this.width.toString()}x${this.height}`;
          return `image<${this.colorSpace},${size}>`;
      }
  };
  /** Optional wrapper around another `type`; prints as 'optional<type>'. */
  coreml.OptionalType = class {

      constructor(type) {
          this.type = type;
      }

      equals(obj) {
          if (!(obj instanceof coreml.OptionalType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      toString() {
          const inner = `${this.type}`;
          return `optional<${inner}>`;
      }
  };
  /** State wrapper around another `type`; prints as 'state<type>'. */
  coreml.StateType = class {

      constructor(type) {
          this.type = type;
      }

      equals(obj) {
          if (!(obj instanceof coreml.StateType)) {
              return false;
          }
          return this.type.equals(obj.type);
      }

      toString() {
          const inner = `${this.type}`;
          return `state<${inner}>`;
      }
  };
  /**
   * Model-level view of a decoded Model message (as returned by coreml.proto.Model.decode).
   * Fields built:
   * - `functions`: one graph per entry of `description.functions`, with the default function first.
   * - `graphs`: the main graph, unless the model is an ML program without a 'main' function.
   * - `version`, `description` and author/license `metadata` arguments, from `description.metadata`.
   */
  coreml.Context = class {

      /**
       * @param {object} metadata - operator metadata, passed through to each graph context.
       * @param {string} format - display format string, e.g. 'Core ML v7'.
       * @param {object} model - decoded Model message (must not be null).
       * @param {Map} [weights] - weight file name -> stream.
       * @param {Map} [values] - shared value table passed to each graph context.
       */
      constructor(metadata, format, model, weights, values) {
          this.format = format;
          this.metadata = [];
          this.graphs = [];
          this.functions = [];
          const description = model.description;
          const functions = description && Array.isArray(description.functions) ? description.functions : [];
          for (const func of functions) {
              this.functions.push(new coreml.Context.Graph(metadata, func.name, 'function', model, func, weights, values));
          }
          // Move the default function to the front. This must search `functions`: `graphs` is still empty here.
          if (description && description.defaultFunctionName) {
              const index = this.functions.findIndex((graph) => graph.name === description.defaultFunctionName);
              if (index > 0) {
                  const [graph] = this.functions.splice(index, 1);
                  this.functions.unshift(graph);
              }
          }
          // Non-program models always get a main graph. ML programs get one only when they define 'main'.
          const program = model.mlProgram;
          if (!program || (program.functions && program.functions.main)) {
              this.graphs.push(new coreml.Context.Graph(metadata, '', 'graph', model, description, weights, values));
          }
          if (description && description.metadata) {
              const info = description.metadata;
              if (info.versionString) {
                  this.version = info.versionString;
              }
              if (info.shortDescription) {
                  this.description = info.shortDescription;
              }
              if (info.author) {
                  this.metadata.push(new coreml.Argument('author', info.author));
              }
              if (info.license) {
                  this.metadata.push(new coreml.Argument('license', info.license));
              }
              // NOTE: info.userDefined entries are currently not surfaced.
          }
      }
  };
  521. coreml.Context.Graph = class {
  522. constructor(metadata, name, type, model, description, weights, values) {
  523. this.metadata = metadata;
  524. this.name = name;
  525. this.type = type;
  526. this.weights = weights || new Map();
  527. this.values = values || new Map();
  528. this.nodes = [];
  529. this.inputs = [];
  530. this.outputs = [];
  531. if (description) {
  532. const inputs = description && Array.isArray(description.input) ? description.input : [];
  533. for (const description of inputs) {
  534. const value = this.output(description.name);
  535. this.update(value, description);
  536. this.inputs.push({ name: description.name, visible: true, value: [value] });
  537. }
  538. const state = description && Array.isArray(description.state) ? description.state : [];
  539. for (const description of state) {
  540. const value = this.output(description.name);
  541. this.update(value, description);
  542. this.inputs.push({ name: description.name, visible: true, value: [value] });
  543. }
  544. this.description = this.model(model, '', description);
  545. const outputs = description && Array.isArray(description.output) ? description.output : [];
  546. for (const description of outputs) {
  547. const value = this.input(description.name);
  548. this.update(value, description);
  549. this.outputs.push({ name: description.name, visible: true, value: [value] });
  550. }
  551. }
  552. }
  553. context() {
  554. return new coreml.Context.Graph(this.metadata, '', 'graph', null, null, this.weights, this.values);
  555. }
  556. network(obj) {
  557. const context = this.context();
  558. for (const layer of obj.layers) {
  559. const type = layer.layer;
  560. context.node(context.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  561. }
  562. context.updatePreprocessing('', obj.preprocessing, null);
  563. context.description = 'Neural Network';
  564. return context;
  565. }
  566. input(name) {
  567. if (!this.values.has(name)) {
  568. this.values.set(name, { counter: 0, name, to: [], from: [] });
  569. }
  570. return this.values.get(name);
  571. }
  572. output(name) {
  573. if (this.values.has(name)) {
  574. const value = { ...this.values.get(name) };
  575. value.counter++;
  576. value.name = `${name}|${value.counter}`; // custom argument id
  577. this.values.set(name, value);
  578. this.values.set(value.name, value);
  579. } else {
  580. const value = { counter: 0, name, to: [], from: [] };
  581. this.values.set(name, value);
  582. const key = `${name}|${value.counter}`;
  583. this.values.set(key, value);
  584. }
  585. return this.values.get(name);
  586. }
  587. update(value, description) {
  588. if (!value.type) {
  589. value.type = coreml.Utility.featureType(description.type);
  590. }
  591. if (!value.description && description.shortDescription) {
  592. value.description = description.shortDescription;
  593. }
  594. }
  595. node(group, type, name, description, data, inputs, outputs, inputTensors, outputTensors) {
  596. const obj = {
  597. group,
  598. type,
  599. name,
  600. description,
  601. attributes: {},
  602. inputs: [],
  603. outputs: []
  604. };
  605. inputs = inputs.map((input, index) => {
  606. const value = this.input(input);
  607. if (!value.type && inputTensors && index < inputTensors.length) {
  608. const tensor = inputTensors[index];
  609. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  610. value.type = new coreml.TensorType('?', shape);
  611. }
  612. return value;
  613. });
  614. outputs = outputs.map((output, index) => {
  615. const value = this.output(output);
  616. if (!value.type && outputTensors && index < outputTensors.length) {
  617. const tensor = outputTensors[index];
  618. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  619. value.type = new coreml.TensorType('?', shape);
  620. }
  621. return value;
  622. });
  623. const initializers = [];
  624. const initializer = (type, name, shape, data) => {
  625. let dataType = '?';
  626. let quantization = null;
  627. let values = null;
  628. if (data) {
  629. if (data.floatValue && data.floatValue.length > 0) {
  630. values = data.floatValue;
  631. dataType = 'float32';
  632. } else if (data.float16Value && data.float16Value.length > 0) {
  633. values = data.float16Value; // byte[]
  634. dataType = 'float16';
  635. } else if (data.rawValue && data.rawValue.length > 0) {
  636. if (data.quantization) {
  637. values = data.rawValue;
  638. dataType = `uint${data.quantization.numberOfBits}`;
  639. } else {
  640. shape = [];
  641. }
  642. }
  643. quantization = data.quantization || null;
  644. }
  645. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  646. const tensor = new coreml.Tensor(tensorType, values, quantization, 'Weights');
  647. const input = this.metadata.input(type, name);
  648. const visible = input && input.visible === false ? false : true;
  649. const value = { value: new coreml.Value('', null, null, tensor) };
  650. initializers.push({ name, visible, value: [value] });
  651. };
  652. const vector = (value) => {
  653. return (value && Object.keys(value).length === 1 && value.vector) ? value.vector : value;
  654. };
  655. const weights = (type, data) => {
  656. switch (type) {
  657. case 'convolution': {
  658. const weightsShape = [data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1]];
  659. if (data.isDeconvolution) {
  660. weightsShape[0] = data.kernelChannels;
  661. weightsShape[1] = Math.floor(Number(data.outputChannels / (data.nGroups === 0 ? 1 : data.nGroups)));
  662. }
  663. initializer(type, 'weights', weightsShape, data.weights);
  664. if (data.hasBias) {
  665. initializer(type, 'bias', [data.outputChannels], data.bias);
  666. }
  667. return { 'weights': true, 'bias': data.hasBias };
  668. }
  669. case 'innerProduct':
  670. initializer(type, 'weights', [data.outputChannels, data.inputChannels], data.weights);
  671. if (data.hasBias) {
  672. initializer(type, 'bias', [data.outputChannels], data.bias);
  673. }
  674. return { 'weights': true, 'bias': data.hasBias };
  675. case 'batchnorm':
  676. initializer(type, 'gamma', [data.channels], data.gamma);
  677. initializer(type, 'beta', [data.channels], data.beta);
  678. if (data.mean) {
  679. initializer(type, 'mean', [data.channels], data.mean);
  680. }
  681. if (data.variance) {
  682. initializer(type, 'variance', [data.channels], data.variance);
  683. }
  684. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  685. case 'embedding':
  686. initializer(type, 'weights', [data.inputDim, data.outputChannels], data.weights);
  687. return { 'weights': true };
  688. case 'loadConstant':
  689. case 'loadConstantND':
  690. initializer(type, 'data', data.shape, data.data);
  691. return { 'data': true };
  692. case 'scale':
  693. initializer(type, 'scale', data.shapeScale, data.scale);
  694. if (data.hasBias) {
  695. initializer(type, 'bias', data.shapeBias, data.bias);
  696. }
  697. return { 'scale': true, 'bias': data.hasBias };
  698. case 'bias':
  699. initializer(type, 'bias', data.shape, data.bias);
  700. return { 'bias': true };
  701. case 'simpleRecurrent':
  702. initializer(type, 'weights', [data.outputVectorSize, data.inputVectorSize], data.weightMatrix);
  703. initializer(type, 'recurrent', [data.outputVectorSize, data.inputVectorSize], data.recursionMatrix);
  704. if (data.hasBiasVectors) {
  705. initializer(type, 'bias', [data.outputVectorSize], data.biasVector);
  706. }
  707. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  708. case 'gru': {
  709. const recursionMatrixShape = [data.outputVectorSize, data.outputVectorSize];
  710. const weightMatrixShape = [data.outputVectorSize, data.inputVectorSize];
  711. const biasVectorShape = [data.outputVectorSize];
  712. initializer(type, 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  713. initializer(type, 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  714. initializer(type, 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  715. initializer(type, 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  716. initializer(type, 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  717. initializer(type, 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  718. if (data.hasBiasVectors) {
  719. initializer(type, 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  720. initializer(type, 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  721. initializer(type, 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  722. }
  723. return {
  724. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  725. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  726. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  727. };
  728. }
  729. case 'uniDirectionalLSTM':
  730. case 'biDirectionalLSTM': {
  731. const count = (type === 'uniDirectionalLSTM') ? 1 : 2;
  732. const h = data.outputVectorSize;
  733. const x = data.inputVectorSize;
  734. for (let i = 0; i < count; i++) {
  735. const weights = count === 1 ? data.weightParams : data.weightParams[i];
  736. const suffix = (i === 0) ? '' : '_rev';
  737. initializer(type, `inputGateWeightMatrix${suffix}`, [h,x], weights.inputGateWeightMatrix);
  738. initializer(type, `forgetGateWeightMatrix${suffix}`, [h,x], weights.forgetGateWeightMatrix);
  739. initializer(type, `blockInputWeightMatrix${suffix}`, [h,x], weights.blockInputWeightMatrix);
  740. initializer(type, `outputGateWeightMatrix${suffix}`, [h,x], weights.outputGateWeightMatrix);
  741. initializer(type, `inputGateRecursionMatrix${suffix}`, [h,h], weights.inputGateRecursionMatrix);
  742. initializer(type, `forgetGateRecursionMatrix${suffix}`, [h,h],weights.forgetGateRecursionMatrix);
  743. initializer(type, `blockInputRecursionMatrix${suffix}`, [h,h], weights.blockInputRecursionMatrix);
  744. initializer(type, `outputGateRecursionMatrix${suffix}`, [h,h], weights.outputGateRecursionMatrix);
  745. if (data.params.hasBiasVectors) {
  746. initializer(type, `inputGateBiasVector${suffix}`, [h], weights.inputGateBiasVector);
  747. initializer(type, `forgetGateBiasVector${suffix}`, [h], weights.forgetGateBiasVector);
  748. initializer(type, `blockInputBiasVector${suffix}`, [h], weights.blockInputBiasVector);
  749. initializer(type, `outputGateBiasVector${suffix}`, [h], weights.outputGateBiasVector);
  750. }
  751. if (data.params.hasPeepholeVectors) {
  752. initializer(type, `inputGatePeepholeVector${suffix}`, [h], weights.inputGatePeepholeVector);
  753. initializer(type, `forgetGatePeepholeVector${suffix}`, [h], weights.forgetGatePeepholeVector);
  754. initializer(type, `outputGatePeepholeVector${suffix}`, [h], weights.outputGatePeepholeVector);
  755. }
  756. }
  757. return { 'weightParams': true };
  758. }
  759. case 'dictVectorizer':
  760. data.stringToIndex = vector(data.stringToIndex);
  761. return {};
  762. case 'wordTagger':
  763. data.modelParameterData = Array.from(data.modelParameterData);
  764. data.stringTags = vector(data.stringTags);
  765. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  766. case 'textClassifier':
  767. data.modelParameterData = Array.from(data.modelParameterData);
  768. data.stringClassLabels = vector(data.stringClassLabels);
  769. return {};
  770. case 'nonMaximumSuppression':
  771. data.stringClassLabels = vector(data.stringClassLabels);
  772. return {};
  773. default:
  774. return {};
  775. }
  776. };
  777. if (data) {
  778. const attributes = obj.attributes;
  779. const map = weights(type, data, initializers);
  780. for (const [name, value] of Object.entries(data)) {
  781. if (!map[name]) {
  782. attributes[name] = value;
  783. }
  784. }
  785. switch (obj.type) {
  786. case 'loop':
  787. attributes.bodyNetwork = this.network(attributes.bodyNetwork);
  788. attributes.conditionNetwork = this.network(attributes.conditionNetwork);
  789. break;
  790. case 'branch':
  791. attributes.ifBranch = this.network(attributes.ifBranch);
  792. attributes.elseBranch = this.network(attributes.elseBranch);
  793. break;
  794. default:
  795. break;
  796. }
  797. }
  798. const metadata = this.metadata.type(type);
  799. for (let i = 0; i < inputs.length;) {
  800. const input = metadata && metadata.inputs && i < metadata.inputs.length ? metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
  801. const count = input.type === 'Tensor[]' ? inputs.length - i : 1;
  802. const values = inputs.slice(i, i + count);
  803. obj.inputs.push({ name: input.name, visible: true, value: values });
  804. i += count;
  805. }
  806. obj.inputs.push(...initializers);
  807. for (let i = 0; i < outputs.length;) {
  808. const output = metadata && metadata.outputs && i < metadata.outputs.length ? metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
  809. const count = output.type === 'Tensor[]' ? outputs.length - i : 1;
  810. const args = outputs.slice(i, i + count);
  811. obj.outputs.push({ name: output.name, visible: true, value: args });
  812. i += count;
  813. }
  814. this.nodes.push(obj);
  815. return obj;
  816. }
  817. model(model, group, description) {
  818. this.groups |= group.length > 0;
  819. const shortDescription = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  820. switch (model.Type) {
  821. case 'neuralNetworkClassifier': {
  822. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  823. for (const layer of neuralNetworkClassifier.layers) {
  824. const type = layer.layer;
  825. this.node(group, type, layer.name, group === '' ? '' : shortDescription, layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  826. }
  827. this.updateClassifierOutput(group, neuralNetworkClassifier, description);
  828. this.updatePreprocessing(group, neuralNetworkClassifier.preprocessing, description);
  829. return 'Neural Network Classifier';
  830. }
  831. case 'neuralNetwork': {
  832. const neuralNetwork = model.neuralNetwork;
  833. for (const layer of neuralNetwork.layers) {
  834. this.node(group, layer.layer, layer.name, group === '' ? '' : shortDescription, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  835. }
  836. this.updatePreprocessing(group, neuralNetwork.preprocessing, description);
  837. return 'Neural Network';
  838. }
  839. case 'neuralNetworkRegressor': {
  840. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  841. for (const layer of neuralNetworkRegressor.layers) {
  842. this.node(group, layer.layer, layer.name, shortDescription, layer[layer.layer], layer.input, layer.output);
  843. }
  844. this.updatePreprocessing(group, neuralNetworkRegressor, description);
  845. return 'Neural Network Regressor';
  846. }
  847. case 'pipeline': {
  848. for (let i = 0; i < model.pipeline.models.length; i++) {
  849. this.model(model.pipeline.models[i], `${group ? (`${group}/`) : ''}pipeline[${i}]`, description);
  850. }
  851. return 'Pipeline';
  852. }
  853. case 'pipelineClassifier': {
  854. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  855. this.model(model.pipelineClassifier.pipeline.models[i], `${group ? (`${group}/`) : ''}pipelineClassifier[${i}]`, description);
  856. }
  857. return 'Pipeline Classifier';
  858. }
  859. case 'pipelineRegressor': {
  860. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  861. this.model(model.pipelineRegressor.pipeline.models[i], `${group ? (`${group}/`) : ''}pipelineRegressor[${i}]`, description);
  862. }
  863. return 'Pipeline Regressor';
  864. }
  865. case 'glmClassifier': {
  866. this.node(group, 'glmClassifier', null, shortDescription,
  867. {
  868. classEncoding: model.glmClassifier.classEncoding,
  869. offset: model.glmClassifier.offset,
  870. weights: model.glmClassifier.weights
  871. },
  872. [model.description.input[0].name],
  873. [model.description.output[0].name]);
  874. this.updateClassifierOutput(group, model.glmClassifier, description);
  875. return 'Generalized Linear Classifier';
  876. }
  877. case 'glmRegressor': {
  878. this.node(group, 'glmRegressor', null, shortDescription,
  879. model.glmRegressor,
  880. [model.description.input[0].name],
  881. [model.description.output[0].name]);
  882. return 'Generalized Linear Regressor';
  883. }
  884. case 'treeEnsembleClassifier': {
  885. this.node(group, 'treeEnsembleClassifier', null, shortDescription,
  886. model.treeEnsembleClassifier.treeEnsemble,
  887. [model.description.input[0].name],
  888. [model.description.output[0].name]);
  889. this.updateClassifierOutput(group, model.treeEnsembleClassifier, description);
  890. return 'Tree Ensemble Classifier';
  891. }
  892. case 'treeEnsembleRegressor': {
  893. this.node(group, 'treeEnsembleRegressor', null, shortDescription,
  894. model.treeEnsembleRegressor.treeEnsemble,
  895. [model.description.input[0].name],
  896. [model.description.output[0].name]);
  897. return 'Tree Ensemble Regressor';
  898. }
  899. case 'supportVectorClassifier': {
  900. this.node(group, 'supportVectorClassifier', null, shortDescription,
  901. {
  902. coefficients: model.supportVectorClassifier.coefficients,
  903. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  904. kernel: model.supportVectorClassifier.kernel,
  905. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  906. probA: model.supportVectorClassifier.probA,
  907. probB: model.supportVectorClassifier.probB,
  908. rho: model.supportVectorClassifier.rho,
  909. supportVectors: model.supportVectorClassifier.supportVectors
  910. },
  911. [model.description.input[0].name],
  912. [model.description.output[0].name]);
  913. this.updateClassifierOutput(group, model.supportVectorClassifier, description);
  914. return 'Support Vector Classifier';
  915. }
  916. case 'supportVectorRegressor': {
  917. this.node(group, 'supportVectorRegressor', null, shortDescription,
  918. {
  919. coefficients: model.supportVectorRegressor.coefficients,
  920. kernel: model.supportVectorRegressor.kernel,
  921. rho: model.supportVectorRegressor.rho,
  922. supportVectors: model.supportVectorRegressor.supportVectors
  923. },
  924. [model.description.input[0].name],
  925. [model.description.output[0].name]);
  926. return 'Support Vector Regressor';
  927. }
  928. case 'oneHotEncoder': {
  929. const categoryType = model.oneHotEncoder.CategoryType;
  930. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  931. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  932. this.node(group, 'oneHotEncoder', null, shortDescription,
  933. oneHotEncoderParams,
  934. [model.description.input[0].name],
  935. [model.description.output[0].name]);
  936. return 'One Hot Encoder';
  937. }
  938. case 'imputer': {
  939. const imputedValue = model.imputer.ImputedValue;
  940. const replaceValue = model.imputer.ReplaceValue;
  941. const imputerParams = {};
  942. imputerParams[imputedValue] = model.imputer[imputedValue];
  943. imputerParams[replaceValue] = model.imputer[replaceValue];
  944. this.node(group, 'oneHotEncoder', null, shortDescription,
  945. imputerParams,
  946. [model.description.input[0].name],
  947. [model.description.output[0].name]);
  948. return 'Imputer';
  949. }
  950. case 'featureVectorizer': {
  951. this.node(group, 'featureVectorizer', null, shortDescription,
  952. model.featureVectorizer,
  953. model.description.input.map((item) => item.name),
  954. [model.description.output[0].name]);
  955. return 'Feature Vectorizer';
  956. }
  957. case 'dictVectorizer': {
  958. this.node(group, 'dictVectorizer', null, shortDescription,
  959. model.dictVectorizer,
  960. [model.description.input[0].name],
  961. [model.description.output[0].name]);
  962. return 'Dictionary Vectorizer';
  963. }
  964. case 'scaler': {
  965. this.node(group, 'scaler', null, shortDescription,
  966. model.scaler,
  967. [model.description.input[0].name],
  968. [model.description.output[0].name]);
  969. return 'Scaler';
  970. }
  971. case 'categoricalMapping': {
  972. this.node(group, 'categoricalMapping', null, shortDescription,
  973. model.categoricalMapping,
  974. [model.description.input[0].name],
  975. [model.description.output[0].name]);
  976. return 'Categorical Mapping';
  977. }
  978. case 'normalizer': {
  979. this.node(group, 'normalizer', null, shortDescription,
  980. model.normalizer,
  981. [model.description.input[0].name],
  982. [model.description.output[0].name]);
  983. return 'Normalizer';
  984. }
  985. case 'arrayFeatureExtractor': {
  986. this.node(group, 'arrayFeatureExtractor', null, shortDescription,
  987. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  988. [model.description.input[0].name],
  989. [model.description.output[0].name]);
  990. return 'Array Feature Extractor';
  991. }
  992. case 'nonMaximumSuppression': {
  993. const nonMaximumSuppressionParams = {
  994. pickTop: model.nonMaximumSuppression.pickTop,
  995. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  996. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  997. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  998. };
  999. this.node(group, 'nonMaximumSuppression', null, shortDescription,
  1000. nonMaximumSuppressionParams,
  1001. [
  1002. model.nonMaximumSuppression.confidenceInputFeatureName,
  1003. model.nonMaximumSuppression.coordinatesInputFeatureName,
  1004. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  1005. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  1006. ],
  1007. [
  1008. model.nonMaximumSuppression.confidenceOutputFeatureName,
  1009. model.nonMaximumSuppression.coordinatesOutputFeatureName
  1010. ]);
  1011. return 'Non Maximum Suppression';
  1012. }
  1013. case 'wordTagger': {
  1014. this.node(group, 'wordTagger', null, shortDescription,
  1015. model.wordTagger,
  1016. [model.description.input[0].name],
  1017. [
  1018. model.wordTagger.tokensOutputFeatureName,
  1019. model.wordTagger.tokenTagsOutputFeatureName,
  1020. model.wordTagger.tokenLocationsOutputFeatureName,
  1021. model.wordTagger.tokenLengthsOutputFeatureName
  1022. ]);
  1023. return 'Word Tagger';
  1024. }
  1025. case 'textClassifier': {
  1026. this.node(group, 'textClassifier', null, shortDescription,
  1027. model.textClassifier,
  1028. [model.description.input[0].name],
  1029. [model.description.output[0].name]);
  1030. return 'Text Classifier';
  1031. }
  1032. case 'visionFeaturePrint': {
  1033. const visionFeaturePrintParams = {
  1034. scene: model.visionFeaturePrint.scene
  1035. };
  1036. this.node(group, 'visionFeaturePrint', null, shortDescription,
  1037. visionFeaturePrintParams,
  1038. [model.description.input[0].name],
  1039. [model.description.output[0].name]);
  1040. return 'Vision Feature Print';
  1041. }
  1042. case 'soundAnalysisPreprocessing': {
  1043. this.node(group, 'soundAnalysisPreprocessing', null, shortDescription,
  1044. model.soundAnalysisPreprocessing,
  1045. [model.description.input[0].name],
  1046. [model.description.output[0].name]);
  1047. return 'Sound Analysis Preprocessing';
  1048. }
  1049. case 'kNearestNeighborsClassifier': {
  1050. this.node(group, 'kNearestNeighborsClassifier', null, shortDescription,
  1051. model.kNearestNeighborsClassifier,
  1052. [model.description.input[0].name],
  1053. [model.description.output[0].name]);
  1054. this.updateClassifierOutput(group, model.kNearestNeighborsClassifier, description);
  1055. return 'Nearest Neighbors Classifier';
  1056. }
  1057. case 'itemSimilarityRecommender': {
  1058. this.node(group, 'itemSimilarityRecommender', null, shortDescription,
  1059. {
  1060. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  1061. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  1062. },
  1063. model.description.input.map((feature) => feature.name),
  1064. model.description.output.map((feature) => feature.name));
  1065. return 'Item Similarity Recommender';
  1066. }
  1067. case 'audioFeaturePrint': {
  1068. this.node(group, 'audioFeaturePrint', null, shortDescription,
  1069. model.audioFeaturePrint,
  1070. [model.description.input[0].name],
  1071. [model.description.output[0].name]);
  1072. return 'Audio Feature Print';
  1073. }
  1074. case 'linkedModel': {
  1075. this.node(group, 'linkedModel', null, shortDescription,
  1076. model.linkedModel.linkedModelFile,
  1077. [model.description.input[0].name],
  1078. [model.description.output[0].name]);
  1079. return 'Linked Model';
  1080. }
  1081. case 'customModel': {
  1082. this.node(group, 'customModel', null, shortDescription,
  1083. { className: model.customModel.className, parameters: model.customModel.parameters },
  1084. [model.description.input[0].name],
  1085. [model.description.output[0].name]);
  1086. return 'customModel';
  1087. }
  1088. case 'mlProgram': {
  1089. return this.program(model.mlProgram, group);
  1090. }
  1091. default: {
  1092. throw new coreml.Error(`Unsupported model type '${JSON.stringify(Object.keys(model))}'.`);
  1093. }
  1094. }
  1095. }
  1096. updateClassifierOutput(group, classifier, description) {
  1097. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  1098. if (!labelProbabilityLayerName && this.nodes.length > 0) {
  1099. const node = this.nodes.slice(-1).pop();
  1100. if (node && node.outputs.length === 1 && node.outputs[0].value.length === 1) {
  1101. labelProbabilityLayerName = node.outputs[0].value[0].name;
  1102. }
  1103. }
  1104. let predictedFeatureName = description.predictedFeatureName;
  1105. let predictedProbabilitiesName = description.predictedProbabilitiesName;
  1106. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  1107. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  1108. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  1109. const labelProbabilityInput = `${labelProbabilityLayerName}:labelProbabilityLayerName`;
  1110. const values = new Set();
  1111. for (const node of this.nodes) {
  1112. for (const output of node.outputs) {
  1113. for (const value of output.value) {
  1114. if (value.name === labelProbabilityLayerName) {
  1115. value.name = labelProbabilityInput;
  1116. values.add(value);
  1117. }
  1118. }
  1119. }
  1120. }
  1121. this.values.set(labelProbabilityInput, this.values.get(labelProbabilityLayerName));
  1122. this.values.delete(labelProbabilityLayerName);
  1123. const type = classifier.ClassLabels;
  1124. const node = {
  1125. // group: this._group,
  1126. type,
  1127. name: null,
  1128. description: '',
  1129. attributes: classifier[type] || {}
  1130. };
  1131. node.inputs = [
  1132. { name: 'input', visible: true, value: Array.from(values) }
  1133. ];
  1134. node.outputs = [
  1135. { name: 'probabilities', visible: true, value: [this.output(predictedProbabilitiesName)] },
  1136. { name: 'feature', visible: true, value: [this.output(predictedFeatureName)] }
  1137. ];
  1138. this.nodes.push(node);
  1139. }
  1140. }
  1141. updatePreprocessing(group, preprocessings, description) {
  1142. if (preprocessings && preprocessings.length > 0) {
  1143. const preprocessingInput = description.input[0].name;
  1144. const inputNodes = [];
  1145. for (const node of this.nodes) {
  1146. if (node.inputs.some((input) => Array.isArray(input.value) && input.value.some((arg) => arg.name === preprocessingInput))) {
  1147. inputNodes.push(node);
  1148. }
  1149. }
  1150. let currentOutput = preprocessingInput;
  1151. let preprocessorOutput = null;
  1152. let preprocessorIndex = 0;
  1153. for (const preprocessing of preprocessings) {
  1154. const input = preprocessing.featureName ? preprocessing.featureName : currentOutput;
  1155. currentOutput = `${preprocessingInput}:${preprocessorIndex}`;
  1156. const preprocessor = preprocessing.preprocessor;
  1157. const node = this.node(group, preprocessor, null, '', preprocessing[preprocessor], [input], [currentOutput]);
  1158. [preprocessorOutput] = node.outputs[0].value;
  1159. preprocessorIndex++;
  1160. }
  1161. for (const node of inputNodes) {
  1162. for (const input of node.inputs) {
  1163. if (Array.isArray(input.value)) {
  1164. for (let i = 0; i < input.value.length; i++) {
  1165. if (input.value[i].name === preprocessingInput) {
  1166. input.value[i] = preprocessorOutput;
  1167. }
  1168. }
  1169. }
  1170. }
  1171. }
  1172. }
  1173. }
  /**
   * Adds the operations of an ML Program (MIL) function to the graph.
   *
   * Uses the function named `this.name` (default 'main') and the first block
   * specialization whose key starts with 'CoreML'. A `const` operation with no inputs
   * and a single output is folded into its output value. Inputs that end up holding
   * defined non-tensor values are moved into the consuming op's attributes.
   *
   * @param {object} program - Decoded MIL program message.
   * @param {string} group - Stored as `op.group` on every emitted node.
   * @returns {string} Always 'ML Program'.
   * @throws {coreml.Error} On unsupported value kinds, tensor/blob data types, or a
   *     duplicate value name whose type or initializer conflicts.
   */
  program(program, group) {
      // TODO: need to handle functions other than main?
      const name = this.name || 'main';
      const main = program.functions[name];
      // TODO: need to handle more than one block specialization?
      const block_specializations = main.block_specializations;
      const key = Object.keys(block_specializations).filter((key) => key.startsWith('CoreML')).shift();
      // NOTE(review): if no key starts with 'CoreML', 'block' is undefined and
      // 'block.operations' below throws a TypeError.
      const block = block_specializations[key];
      // Converts a MIL value: immediate values become plain JS values/arrays,
      // blob file values become coreml.Tensor objects.
      const convertValue = (value) => {
          switch (value.value) {
              case 'immediateValue': {
                  const tensor = value.immediateValue.tensor;
                  const type = coreml.Utility.valueType(value.type);
                  let values = null;
                  switch (tensor.value) {
                      case 'ints':
                          values = tensor.ints.values;
                          break;
                      case 'strings':
                          values = tensor.strings.values;
                          break;
                      case 'bools':
                          values = tensor.bools.values;
                          break;
                      case 'floats':
                          values = tensor.floats.values;
                          break;
                      case 'bytes':
                          values = tensor.bytes.values;
                          break;
                      default:
                          throw new coreml.Error(`Unsupported tensor value '${tensor.value}'.`);
                  }
                  // Rank-0 tensors are unwrapped to their single element.
                  if (type.shape.dimensions.length === 0) {
                      [values] = values;
                  }
                  return values;
              }
              case 'blobFileValue': {
                  // Reads a blob header at 'offset' in the weight file: a 0xdeadbeef signature,
                  // a data type, the payload size and the payload offset. If the file is not in
                  // this.weights or the signature does not match, the tensor gets null data.
                  const type = coreml.Utility.valueType(value.type);
                  const blob = value.blobFileValue;
                  const offset = Number(blob.offset);
                  const file = blob.fileName;
                  let data = null;
                  const stream = this.weights.get(file);
                  if (stream) {
                      stream.seek(offset);
                      const buffer = stream.read(32);
                      const reader = base.BinaryReader.open(buffer);
                      const signature = reader.uint32();
                      if (signature === 0xdeadbeef) {
                          reader.uint32(); // dataType (unused; the declared value type is used instead)
                          const size = reader.uint64().toNumber();
                          // Payload offset from the header; shadows the outer 'offset'.
                          const offset = reader.uint64().toNumber();
                          stream.seek(offset);
                          const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
                          switch (type.dataType) {
                              case 'float32': {
                                  // NOTE(review): the Float32Array view requires buffer.byteOffset to be a
                                  // multiple of 4, otherwise the constructor throws a RangeError.
                                  const buffer = stream.read(size);
                                  data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
                                  break;
                              }
                              // Other supported types are kept as the raw bytes read from the stream.
                              case 'float16':
                              case 'int1': case 'int2': case 'int3': case 'int4': case 'int6': case 'int8': case 'int32':
                              case 'uint1': case 'uint2': case 'uint3': case 'uint4': case 'uint6': case 'uint8': case 'uint16': {
                                  data = stream.read(size);
                                  break;
                              }
                              default:
                                  throw new coreml.Error(`Unsupported blob data type '${type.dataType}'.`);
                          }
                      }
                  }
                  return new coreml.Tensor(type, data, null, 'Blob');
              }
              default: {
                  throw new coreml.Error(`Unsupported value '${value.value}'.`);
              }
          }
      };
      // Convert every MIL operation into a node skeleton. Named arguments are linked to
      // shared values from this.input(): consumers go into 'to', producers into 'from'.
      const operations = block.operations.map((op) => {
          const operation = {
              type: op.type,
              attributes: {}
          };
          for (const [key, value] of Object.entries(op.attributes)) {
              operation.attributes[key] = convertValue(value);
          }
          operation.inputs = Object.entries(op.inputs).map(([name, input]) => {
              const value = input.arguments.map((argument) => {
                  // Inline blob constants become anonymous (name '') tensor arguments.
                  if (argument.value && argument.value.value && argument.value.blobFileValue) {
                      return { name: '', value: convertValue(argument.value) };
                  }
                  if (argument.name) {
                      const value = this.input(argument.name);
                      value.to.push(operation);
                      return value;
                  }
                  // Anything else keeps the raw, unconverted argument.value.
                  return { value: argument.value };
              });
              return { name, value };
          });
          operation.outputs = op.outputs.map((output) => {
              const value = this.input(output.name);
              value.type = coreml.Utility.valueType(output.type);
              value.from.push(operation);
              return { name: 'output', value: [value] };
          });
          return operation;
      });
      // Fold 'const' operations into their single output value and mark them for removal.
      for (const op of operations) {
          if (op.type === 'const' && op.inputs.length === 0 &&
              op.outputs.length === 1 && op.outputs[0].value.length === 1) {
              const [value] = op.outputs[0].value;
              // NOTE(review): this is a truthiness test, so a falsy constant (e.g. a scalar
              // 0, false or '') is not folded; confirm that is intended.
              if (op.attributes && op.attributes.val) {
                  const type = value.type;
                  const data = op.attributes.val;
                  // A rank-0 float16 stored as two bytes is decoded to a number.
                  // NOTE(review): DataView.prototype.getFloat16 is a recent addition and is
                  // missing from older JavaScript runtimes.
                  if (data instanceof Uint8Array && data.length === 2 &&
                      type.dataType === 'float16' && type.shape.dimensions.length === 0) {
                      const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
                      value.value = view.getFloat16(0, true);
                  } else {
                      value.value = data;
                  }
                  value.const = true;
                  op.delete = true;
              }
          }
      }
      // For a multi-value input that contains folded constants but is not made only of
      // tensors, undo the folding: keep the producing ops and drop the folded values.
      for (const op of operations) {
          for (const input of op.inputs) {
              if (input.value.length > 1 && input.value.some((argument) => argument.const)) {
                  if (!input.value.every((argument) => argument.value instanceof coreml.Tensor)) {
                      for (const value of input.value) {
                          // NOTE(review): only values from this.input() carry a 'from' list; an
                          // anonymous or raw argument in this list would make this loop throw.
                          for (const from of value.from) {
                              from.delete = false;
                          }
                          delete value.value;
                      }
                  }
              }
          }
      }
      // Move inputs whose values are defined non-tensors into the op's attributes.
      for (const op of operations.filter((op) => !op.delete)) {
          op.inputs = op.inputs.filter((input) => {
              if (input.value.every((value) => value.value === undefined || value.value instanceof coreml.Tensor)) {
                  return true;
              }
              // NOTE(review): for multi-value inputs only element [0] of each argument's value is kept.
              op.attributes[input.name] = input.value.length === 1 ?
                  input.value[0].value :
                  input.value.map((argument) => argument.value[0]);
              return false;
          });
      }
      // Registers a value in this.values. Tensor values become initializers; anonymous
      // tensors are keyed by the value object itself. Throws when a name is reused with
      // a different type or initializer.
      const mapValue = (name, value) => {
          if (value.value instanceof coreml.Tensor) {
              value.initializer = value.value;
              delete value.value;
              if (name === '') {
                  this.values.set(value, value);
                  return value;
              }
          }
          if (!this.values.has(name)) {
              this.values.set(name, value);
          } else if ((value.type && !value.type.equals(this.values.get(name).type)) ||
              (value.initializer && value.initializer !== this.values.get(name).initializer)) {
              throw new coreml.Error(`Duplicate value '${name}'.`);
          }
          return this.values.get(name);
      };
      for (const op of operations.filter((op) => !op.delete)) {
          for (const argument of op.inputs) {
              for (const value of argument.value) {
                  mapValue(value.name, value);
              }
          }
          for (const argument of op.outputs) {
              for (const value of argument.value) {
                  mapValue(value.name, value);
              }
          }
      }
      // Emit the surviving ops as nodes. Types get the 'program:' prefix, and inputs are
      // ordered by the metadata declaration order; unknown names sort with the last
      // declared position.
      for (const op of operations.filter((op) => !op.delete)) {
          op.group = group;
          op.type = `program:${op.type}`;
          const metadata = this.metadata.type(op.type);
          if (metadata && Array.isArray(metadata.inputs)) {
              const map = new Map(metadata.inputs.map((input, index) => [input.name, index + 1]));
              op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
          }
          this.nodes.push(op);
      }
      return 'ML Program';
  }
  1369. };
  /**
   * Static helpers that translate decoded Core ML / MIL protobuf messages into the
   * viewer's type objects (TensorType, MapType, SequenceType, ...).
   */
  coreml.Utility = class {

      /**
       * Looks up the key name of an enum value declared under `coreml.proto`.
       * @param {string} name - Dotted path below `coreml.proto`, e.g. 'Outer.Inner'.
       * @param {*} value - Enum value to translate.
       * @returns {*} The matching key, or `value` unchanged when the path or value is unknown.
       */
      static enum(name, value) {
          const enumType = name.split('.').reduce((scope, part) => (scope ? scope[part] : scope), coreml.proto);
          if (!enumType) {
              return value;
          }
          // Reverse maps (value -> key) are built once per enum path and cached on the class.
          coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
          const cache = coreml.Utility._enumKeyMap;
          if (!cache.has(name)) {
              cache.set(name, new Map(Object.entries(enumType).map(([key, code]) => [code, key])));
          }
          const keys = cache.get(name);
          return keys.has(value) ? keys.get(value) : value;
      }

      /**
       * Converts a Core ML FeatureType message into a viewer type.
       * @param {object} type - FeatureType-like message whose `Type` names the populated field.
       * @returns {object|string} The converted type ('?' when `type` is falsy), wrapped in
       *     OptionalType when `type.isOptional` is set.
       * @throws {coreml.Error} For unknown feature kinds or array data types.
       */
      static featureType(type) {
          if (!type) {
              return '?';
          }
          let result = null;
          const kind = type.Type;
          if (kind === 'arrayType' || kind === 'multiArrayType') {
              const arrayType = type[kind];
              const hasShape = arrayType.shape && arrayType.shape.length > 0;
              const shape = hasShape ?
                  new coreml.TensorShape(arrayType.shape.map((dim) => Number(dim))) :
                  new coreml.TensorShape([]);
              const codes = coreml.proto.ArrayFeatureType.ArrayDataType;
              // Checked in order with strict equality; the first match wins.
              const names = [
                  [codes.INVALID_ARRAY_DATA_TYPE, '?'],
                  [codes.FLOAT16, 'float16'],
                  [codes.FLOAT32, 'float32'],
                  [codes.DOUBLE, 'float64'],
                  [codes.INT32, 'int32'],
                  [codes.INT8, 'int8']
              ];
              const match = names.find(([code]) => code === arrayType.dataType);
              if (!match) {
                  throw new coreml.Error(`Unsupported array data type '${arrayType.dataType}'.`);
              }
              result = new coreml.TensorType(match[1], shape);
          } else if (kind === 'stringType') {
              result = new coreml.TensorType('string');
          } else if (kind === 'doubleType') {
              result = new coreml.TensorType('float64');
          } else if (kind === 'int64Type') {
              result = new coreml.TensorType('int64');
          } else if (kind === 'dictionaryType') {
              result = new coreml.MapType(type.dictionaryType.KeyType.replace('KeyType', ''), 'float64');
          } else if (kind === 'sequenceType') {
              result = new coreml.SequenceType(coreml.Utility.featureType(type[kind]));
          } else if (kind === 'imageType') {
              const image = type.imageType;
              result = new coreml.ImageType(image.colorSpace, image.width, image.height);
          } else if (kind === 'stateType') {
              result = new coreml.StateType(coreml.Utility.featureType(type.stateType));
          } else {
              throw new coreml.Error(`Unsupported feature type '${type.Type}'.`);
          }
          return type.isOptional ? new coreml.OptionalType(result) : result;
      }

      /**
       * Converts a MIL TensorType message into a coreml.TensorType.
       * Dimensions without a constant size are reported as '?'.
       * @param {object} type - Message with `dataType` and `dimensions`.
       * @returns {coreml.TensorType}
       * @throws {coreml.Error} When `dataType` has no known name.
       */
      static tensorType(type) {
          if (!coreml.Utility._dataTypes) {
              // Lower-cased MILSpec.DataType names keyed by value; 0 is dropped and 1 is 'boolean'.
              const entries = Object.entries(coreml.proto.MILSpec.DataType);
              const table = new Map(entries.map(([key, code]) => [code, key.toLowerCase()]));
              table.delete(0);
              table.set(1, 'boolean');
              coreml.Utility._dataTypes = table;
          }
          const dimensions = type.dimensions.map((dim) => (dim.constant ? dim.constant.size : '?'));
          const dataType = coreml.Utility._dataTypes.get(type.dataType);
          if (!dataType) {
              throw new coreml.Error(`Unsupported data type '${type.dataType}'.`);
          }
          return new coreml.TensorType(dataType, new coreml.TensorShape(dimensions));
      }

      /**
       * Converts a MIL ValueType message (tensor, list or dictionary) into a viewer type.
       * @param {object} type - Message whose `type` names the populated field.
       * @returns {object}
       * @throws {coreml.Error} For unknown value kinds.
       */
      static valueType(type) {
          const kind = type.type;
          if (kind === 'tensorType') {
              return coreml.Utility.tensorType(type.tensorType);
          }
          if (kind === 'listType') {
              return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
          }
          if (kind === 'dictionaryType') {
              const dictionary = type.dictionaryType;
              return new coreml.MapType(coreml.Utility.valueType(dictionary.keyType), coreml.Utility.valueType(dictionary.valueType));
          }
          throw new coreml.Error(`Unsupported value type '${type.type}'.`);
      }
  };
  /**
   * Error raised while loading a Core ML model. It carries a fixed, user-facing `name`
   * in place of the inherited 'Error'.
   */
  coreml.Error = class extends Error {
      constructor(text) {
          super(text);
          this.name = 'Error loading Core ML model.';
      }
  };
  // Re-export coreml.ModelFactory as this module's named export 'ModelFactory'.
  const { ModelFactory } = coreml;
  export { ModelFactory };