coreml.js 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460
  1. import * as base from './base.js';
  2. import * as protobuf from './protobuf.js';
// Namespace object; the Core ML factory, model, graph, value, type and context classes below are attached to it.
const coreml = {};
  4. coreml.ModelFactory = class {
  5. match(context) {
  6. const stream = context.stream;
  7. const identifier = context.identifier.toLowerCase();
  8. const extension = identifier.split('.').pop().toLowerCase();
  9. const tags = context.tags('pb');
  10. if (tags.get(1) === 0 && tags.get(2) === 2) {
  11. if (extension === 'pb') {
  12. const tags = context.tags('pb+');
  13. const keys = Object.keys(tags).map((key) => parseInt(key, 10));
  14. const match = (key) =>
  15. (key >= 200 && key < 220) ||
  16. (key >= 300 && key < 320) ||
  17. (key >= 400 && key < 420) ||
  18. (key >= 500 && key < 520) ||
  19. (key >= 550 && key < 560) ||
  20. (key >= 600 && key < 620) ||
  21. (key === 900) ||
  22. (key >= 2000 && key < 2010) ||
  23. (key === 3000);
  24. if (!keys.some((key) => match(key))) {
  25. return null;
  26. }
  27. }
  28. return 'coreml.pb';
  29. }
  30. if (extension === 'pbtxt') {
  31. const tags = context.tags('pbtxt');
  32. if (tags.has('specificationVersion') && tags.has('description')) {
  33. return 'coreml.pbtxt';
  34. }
  35. }
  36. if (identifier === 'manifest.json') {
  37. const obj = context.peek('json');
  38. if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
  39. const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
  40. if (entries.filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel').length === 1)) {
  41. return 'coreml.manifest';
  42. }
  43. }
  44. }
  45. if (identifier === 'metadata.json') {
  46. const obj = context.peek('json');
  47. if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
  48. return 'coreml.metadata';
  49. }
  50. }
  51. if (identifier === 'featuredescriptions.json') {
  52. const obj = context.peek('json');
  53. if (obj && (obj.Inputs || obj.Outputs)) {
  54. return 'coreml.featuredescriptions';
  55. }
  56. }
  57. if (extension === 'bin' && stream.length > 16) {
  58. const buffer = stream.peek(Math.min(256, stream.length));
  59. for (let i = 0; i < buffer.length - 4; i++) {
  60. const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer [i + 3] << 24) >>> 0;
  61. if (signature === 0xdeadbeef) {
  62. return 'coreml.weights';
  63. }
  64. }
  65. }
  66. return undefined;
  67. }
  68. async open(context, target) {
  69. await context.require('./coreml-proto');
  70. const metadata = await context.metadata('coreml-metadata.json');
  71. const openBinary = async (stream, context, path, format) => {
  72. let model = null;
  73. try {
  74. coreml.proto = protobuf.get('coreml').CoreML.Specification;
  75. const reader = protobuf.BinaryReader.open(stream);
  76. model = coreml.proto.Model.decode(reader);
  77. } catch (error) {
  78. const message = error && error.message ? error.message : error.toString();
  79. throw new coreml.Error('File format is not coreml.Model (' + message.replace(/\.$/, '') + ').');
  80. }
  81. const weightPaths = new Set();
  82. const walkProgram = (program) => {
  83. for (const func of Object.values(program.functions)) {
  84. for (const block of Object.values(func.block_specializations)) {
  85. for (const operation of block.operations) {
  86. for (const value of Object.values(operation.attributes)) {
  87. if (value.blobFileValue && value.blobFileValue.fileName) {
  88. weightPaths.add(value.blobFileValue.fileName);
  89. }
  90. }
  91. }
  92. }
  93. }
  94. };
  95. const walkModel = (model) => {
  96. if (model.mlProgram) {
  97. walkProgram(model.mlProgram);
  98. }
  99. if (model.pipeline && model.pipeline.models) {
  100. for (const node of model.pipeline.models) {
  101. walkModel(node);
  102. }
  103. }
  104. if (model.pipelineClassifier && model.pipelineClassifier.pipeline && model.pipelineClassifier.pipeline.models) {
  105. for (const node of model.pipelineClassifier.pipeline.models) {
  106. walkModel(node);
  107. }
  108. }
  109. if (model.pipelineRegressor && model.pipelineRegressor.pipeline && model.pipelineRegressor.pipeline.models) {
  110. for (const node of model.pipelineRegressor.pipeline.models) {
  111. walkModel(node);
  112. }
  113. }
  114. };
  115. walkModel(model);
  116. const weights = new Map();
  117. if (weightPaths.size > 0) {
  118. const folder = path.replace(/\/[^/]*$/, '');
  119. const keys = Array.from(weightPaths);
  120. const paths = keys.map((path) => path.replace(/^@model_path\//, folder + '/'));
  121. try {
  122. const contexts = await Promise.all(paths.map((path) => context.fetch(path)));
  123. for (let i = 0; i < keys.length; i++) {
  124. weights.set(keys[i], contexts[i].stream);
  125. }
  126. } catch (error) {
  127. // continue regardless of error
  128. }
  129. }
  130. return new coreml.Model(metadata, format, model, weights);
  131. };
  132. const openText = async (stream) => {
  133. let model = null;
  134. try {
  135. coreml.proto = protobuf.get('coreml').CoreML.Specification;
  136. const reader = protobuf.TextReader.open(stream);
  137. model = coreml.proto.Model.decodeText(reader);
  138. } catch (error) {
  139. const message = error && error.message ? error.message : error.toString();
  140. throw new coreml.Error('File format is not coreml.Model (' + message.replace(/\.$/, '') + ').');
  141. }
  142. const weights = new Map();
  143. return new coreml.Model(metadata, null, model, weights);
  144. };
  145. const openManifest = async (obj, context, path) => {
  146. const entries = Object.values(obj.itemInfoEntries).filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel'));
  147. if (entries.length !== 1) {
  148. throw new coreml.Error('Manifest does not contain Core ML model.');
  149. }
  150. const name = path + 'Data/' + entries[0].path;
  151. const content = await context.fetch(name);
  152. return openBinary(content.stream, context, name, 'Core ML Package');
  153. };
  154. const openManifestStream = async (context, path) => {
  155. const name = path + 'Manifest.json';
  156. const content = await context.fetch(name);
  157. const obj = content.read('json');
  158. return openManifest(obj, context, path);
  159. };
  160. switch (target) {
  161. case 'coreml.pb': {
  162. return openBinary(context.stream, context, context.identifier);
  163. }
  164. case 'coreml.pbtxt': {
  165. return openText(context.stream, context, context.identifier);
  166. }
  167. case 'coreml.manifest': {
  168. const obj = context.peek('json');
  169. return openManifest(obj, context, '');
  170. }
  171. case 'coreml.featuredescriptions':
  172. case 'coreml.metadata': {
  173. return openManifestStream(context, '../../');
  174. }
  175. case 'coreml.weights': {
  176. return openManifestStream(context, '../../../');
  177. }
  178. default: {
  179. throw new coreml.Error("Unsupported Core ML format '" + target + "'.");
  180. }
  181. }
  182. }
  183. };
  184. coreml.Model = class {
  185. constructor(metadata, format, model, weights) {
  186. this.format = (format || 'Core ML') + ' v' + model.specificationVersion.toString();
  187. this.metadata = new Map();
  188. const context = new coreml.Context(metadata, model, weights);
  189. const graph = new coreml.Graph(context);
  190. this.graphs = [ graph ];
  191. if (model.description && model.description.metadata) {
  192. const properties = model.description.metadata;
  193. if (properties.versionString) {
  194. this.version = properties.versionString;
  195. }
  196. if (properties.shortDescription) {
  197. this.description = properties.shortDescription;
  198. }
  199. if (properties.author) {
  200. this.metadata.set('author', properties.author);
  201. }
  202. if (properties.license) {
  203. this.metadata.set('license', properties.license);
  204. }
  205. if (metadata.userDefined && Object.keys(properties.userDefined).length > 0) {
  206. /* empty */
  207. }
  208. }
  209. }
  210. };
  211. coreml.Graph = class {
  212. constructor(context) {
  213. this.name = '';
  214. this.type = context.type;
  215. this.groups = context.groups;
  216. for (const value of context.values.values()) {
  217. const name = value.name;
  218. const type = value.type;
  219. const description = value.description;
  220. const initializer = value.initializer;
  221. if (!value.obj) {
  222. value.obj = new coreml.Value(name, type, description, initializer);
  223. }
  224. }
  225. this.inputs = context.inputs.map((argument) => {
  226. const values = argument.value.map((value) => value.obj);
  227. return new coreml.Argument(argument.name, argument.visible, values);
  228. });
  229. this.outputs = context.outputs.map((argument) => {
  230. const values = argument.value.map((value) => value.obj);
  231. return new coreml.Argument(argument.name, argument.visible, values);
  232. });
  233. for (const obj of context.nodes) {
  234. const attributes = obj.attributes;
  235. switch (obj.type) {
  236. case 'loop':
  237. attributes.conditionNetwork = new coreml.Graph(attributes.conditionNetwork);
  238. attributes.bodyNetwork = new coreml.Graph(attributes.bodyNetwork);
  239. break;
  240. case 'branch':
  241. attributes.ifBranch = new coreml.Graph(attributes.ifBranch);
  242. attributes.elseBranch = new coreml.Graph(attributes.elseBranch);
  243. break;
  244. default:
  245. break;
  246. }
  247. }
  248. this.nodes = context.nodes.map((obj) => new coreml.Node(context, obj));
  249. }
  250. };
  251. coreml.Argument = class {
  252. constructor(name, visible, value) {
  253. this.name = name;
  254. this.visible = visible;
  255. this.value = value;
  256. }
  257. };
  258. coreml.Value = class {
  259. constructor(name, type, description, initializer) {
  260. if (typeof name !== 'string') {
  261. throw new coreml.Error("Invalid value identifier '" + JSON.stringify(name) + "'.");
  262. }
  263. this.name = name;
  264. this.type = type ? type : initializer ? initializer.type : null;
  265. this.description = description || null;
  266. this.initializer = initializer || null;
  267. this.quantization = initializer ? initializer.quantization : null;
  268. }
  269. };
  270. coreml.Node = class {
  271. constructor(context, obj) {
  272. if (!obj.type) {
  273. throw new Error('Undefined node type.');
  274. }
  275. if (obj.group) {
  276. this.group = obj.group || null;
  277. }
  278. this.type = Object.assign({}, context.metadata.type(obj.type) || { name: obj.type });
  279. this.type.name = obj.type.split(':').pop();
  280. this.name = obj.name || '';
  281. this.description = obj.description || '';
  282. this.inputs = (obj.inputs || []).map((argument) => {
  283. const values = argument.value.map((value) => value.obj);
  284. return new coreml.Argument(argument.name, argument.visible, values);
  285. });
  286. this.outputs = (obj.outputs || []).map((argument) => {
  287. const values = argument.value.map((value) => value.obj);
  288. return new coreml.Argument(argument.name, argument.visible, values);
  289. });
  290. this.attributes = Object.entries(obj.attributes).map(([name, value]) => {
  291. const metadata = context.metadata.attribute(obj.type, name);
  292. return new coreml.Attribute(metadata, name, value);
  293. });
  294. }
  295. };
  296. coreml.Attribute = class {
  297. constructor(metadata, name, value) {
  298. this.name = name;
  299. this.value = value;
  300. if (this.value instanceof coreml.Tensor) {
  301. this.type = 'tensor';
  302. }
  303. if (metadata) {
  304. if (metadata.type) {
  305. this.type = metadata.type;
  306. }
  307. if (this.type && coreml.proto) {
  308. this.value = coreml.Utility.enum(this.type, this.value);
  309. }
  310. if (metadata.visible === false) {
  311. this.visible = false;
  312. } else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
  313. if (Array.isArray(value)) {
  314. value = value.map((item) => item.toNumber());
  315. }
  316. if (JSON.stringify(metadata.default) == JSON.stringify(value)) {
  317. this.visible = false;
  318. }
  319. }
  320. }
  321. if (this.value instanceof coreml.Graph) {
  322. this.type = 'graph';
  323. }
  324. }
  325. };
  326. coreml.Tensor = class {
  327. constructor(type, values, quantization, category) {
  328. this.type = type;
  329. this.encoding = type.dataType === 'float32' ? '|' : '<';
  330. this.values = values;
  331. this.category = category;
  332. this._quantization = quantization;
  333. }
  334. get quantization() {
  335. if (this._quantization) {
  336. if (this._quantization.lookupTableQuantization &&
  337. this._quantization.lookupTableQuantization.floatValue &&
  338. this._quantization.lookupTableQuantization.floatValue.length > 0) {
  339. const map = [];
  340. for (const key of Object.keys(this._quantization.lookupTableQuantization.floatValue)) {
  341. map.push(key.toString() + ' = ' + this._quantization.lookupTableQuantization.floatValue[key].toString());
  342. }
  343. return map.join('; ');
  344. }
  345. return '?';
  346. }
  347. return null;
  348. }
  349. };
  350. coreml.TensorType = class {
  351. constructor(dataType, shape) {
  352. this.dataType = dataType;
  353. this.shape = shape || new coreml.TensorShape([]);
  354. }
  355. equals(obj) {
  356. return obj && this.dataType === obj.dataType && this.shape && this.shape.equals(obj.shape);
  357. }
  358. toString() {
  359. return this.dataType + this.shape.toString();
  360. }
  361. };
  362. coreml.TensorShape = class {
  363. constructor(dimensions) {
  364. this.dimensions = dimensions.map((dim) => typeof dim === 'string' || Number.isInteger(dim) ? dim : dim.toNumber());
  365. }
  366. equals(obj) {
  367. return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) &&
  368. this.dimensions.length === obj.dimensions.length &&
  369. obj.dimensions.every((value, index) => this.dimensions[index] === value);
  370. }
  371. toString() {
  372. return Array.isArray(this.dimensions) && this.dimensions.length > 0 ?
  373. '[' + this.dimensions.map((dimension) => dimension.toString()).join(',') + ']' : '';
  374. }
  375. };
  376. coreml.ListType = class {
  377. constructor(elementType) {
  378. this.elementType = elementType;
  379. }
  380. equals(obj) {
  381. return obj instanceof coreml.ListType && this.elementType.equals(obj.elementType);
  382. }
  383. toString() {
  384. return 'list<' + this.elementType.toString() + '>';
  385. }
  386. };
  387. coreml.MapType = class {
  388. constructor(keyType, valueType) {
  389. this.keyType = keyType;
  390. this.valueType = valueType;
  391. }
  392. toString() {
  393. return 'map<' + this.keyType + ',' + this.valueType.toString() + '>';
  394. }
  395. };
  396. coreml.SequenceType = class {
  397. constructor(type) {
  398. this.type = type;
  399. }
  400. toString() {
  401. return 'sequence<' + this.type + '>';
  402. }
  403. };
  404. coreml.ImageType = class {
  405. constructor(colorSpace, width, height) {
  406. this.width = width;
  407. this.height = height;
  408. switch (colorSpace) {
  409. case coreml.proto.ImageFeatureType.ColorSpace.GRAYSCALE:
  410. this.colorSpace = 'grayscale';
  411. break;
  412. case coreml.proto.ImageFeatureType.ColorSpace.RGB:
  413. this.colorSpace = 'RGB';
  414. break;
  415. case coreml.proto.ImageFeatureType.ColorSpace.BGR:
  416. this.colorSpace = 'BGR';
  417. break;
  418. case coreml.proto.ImageFeatureType.ColorSpace.GRAYSCALE_FLOAT16:
  419. this.colorSpace = 'grayscale:float16';
  420. break;
  421. default:
  422. throw new coreml.Error("Unsupported image color space '" + colorSpace + "'.");
  423. }
  424. }
  425. equals(obj) {
  426. return obj instanceof coreml.ImageType && this.width === obj.width && this.height === obj.height && this.colorSpace === obj.colorSpace;
  427. }
  428. toString() {
  429. return 'image<' + this.colorSpace + ',' + this.width. toString() + 'x' + this.height.toString() + '>';
  430. }
  431. };
  432. coreml.OptionalType = class {
  433. constructor(type) {
  434. this.type = type;
  435. }
  436. toString() {
  437. return 'optional<' + this.type.toString() + '>';
  438. }
  439. };
  440. coreml.Context = class {
  441. constructor(metadata, model, weights, values) {
  442. this.metadata = metadata;
  443. this.weights = weights;
  444. this.values = values || new Map();
  445. this.nodes = [];
  446. this.inputs = [];
  447. this.outputs = [];
  448. if (model) {
  449. const description = model.description;
  450. const inputs = description && Array.isArray(description.input) ? description.input : [];
  451. for (const description of inputs) {
  452. const value = this.output(description.name);
  453. this.update(value, description);
  454. this.inputs.push({ name: description.name, visible: true, value: [ value ] });
  455. }
  456. this.type = this.model(model, '', description);
  457. const outputs = description && Array.isArray(description.output) ? description.output : [];
  458. for (const description of outputs) {
  459. const value = this.input(description.name);
  460. this.update(value, description);
  461. this.outputs.push({ name: description.name, visible: true, value: [ value ] });
  462. }
  463. }
  464. }
  465. context() {
  466. return new coreml.Context(this.metadata, null, this.weights, this.values);
  467. }
  468. network(obj) {
  469. const context = this.context();
  470. for (const layer of obj.layers) {
  471. const type = layer.layer;
  472. context.node(context.groups, type, layer.name, '', layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  473. }
  474. context.updatePreprocessing('', obj.preprocessing, null);
  475. context.type = 'Neural Network';
  476. return context;
  477. }
  478. input(name) {
  479. if (!this.values.has(name)) {
  480. this.values.set(name, { counter: 0, name: name, to: [], from: [] });
  481. }
  482. return this.values.get(name);
  483. }
  484. output(name) {
  485. if (!this.values.has(name)) {
  486. const value = { counter: 0, name: name, to: [], from: [] };
  487. this.values.set(name, value);
  488. const key = name + '|' + value.counter.toString();
  489. this.values.set(key, value);
  490. } else {
  491. const value = Object.assign({}, this.values.get(name));
  492. value.counter++;
  493. value.name = name + '|' + value.counter.toString(); // custom argument id
  494. this.values.set(name, value);
  495. this.values.set(value.name, value);
  496. }
  497. return this.values.get(name);
  498. }
  499. update(value, description) {
  500. if (!value.type) {
  501. value.type = coreml.Utility.featureType(description.type);
  502. }
  503. if (!value.description && description.shortDescription) {
  504. value.description = description.shortDescription;
  505. }
  506. }
  507. node(group, type, name, description, data, inputs, outputs, inputTensors, outputTensors) {
  508. const obj = {
  509. group: group,
  510. type: type,
  511. name: name,
  512. description: description,
  513. attributes: {},
  514. inputs: [],
  515. outputs: []
  516. };
  517. inputs = inputs.map((input, index) => {
  518. const value = this.input(input);
  519. if (!value.type && inputTensors && index < inputTensors.length) {
  520. const tensor = inputTensors[index];
  521. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  522. value.type = new coreml.TensorType('?', shape);
  523. }
  524. return value;
  525. });
  526. outputs = outputs.map((output, index) => {
  527. const value = this.output(output);
  528. if (!value.type && outputTensors && index < outputTensors.length) {
  529. const tensor = outputTensors[index];
  530. const shape = tensor && tensor.dimValue ? new coreml.TensorShape(tensor.dimValue) : null;
  531. value.type = new coreml.TensorType('?', shape);
  532. }
  533. return value;
  534. });
  535. const initializers = [];
  536. const initializer = (type, name, shape, data) => {
  537. let dataType = '?';
  538. let quantization = null;
  539. let values = null;
  540. if (data) {
  541. if (data.floatValue && data.floatValue.length > 0) {
  542. values = data.floatValue;
  543. dataType = 'float32';
  544. } else if (data.float16Value && data.float16Value.length > 0) {
  545. values = data.float16Value; // byte[]
  546. dataType = 'float16';
  547. } else if (data.rawValue && data.rawValue.length > 0) {
  548. if (data.quantization) {
  549. values = data.rawValue;
  550. dataType = 'uint' + data.quantization.numberOfBits.toString();
  551. } else {
  552. shape = [];
  553. }
  554. }
  555. quantization = data.quantization || null;
  556. }
  557. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  558. const tensor = new coreml.Tensor(tensorType, values, quantization, 'Weights');
  559. const input = this.metadata.input(type, name);
  560. const visible = input && input.visible === false ? false : true;
  561. const value = { obj: new coreml.Value('', null, null, tensor) };
  562. initializers.push({ name: name, visible: visible, value: [ value ] });
  563. };
  564. const vector = (value) => {
  565. return (value && Object.keys(value).length == 1 && value.vector) ? value.vector : value;
  566. };
  567. const weights = (type, data) => {
  568. switch (type) {
  569. case 'convolution': {
  570. const weightsShape = [ data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1] ];
  571. if (data.isDeconvolution) {
  572. weightsShape[0] = data.kernelChannels;
  573. weightsShape[1] = Math.floor(data.outputChannels / (data.nGroups != 0 ? data.nGroups : 1));
  574. }
  575. initializer(type, 'weights', weightsShape, data.weights);
  576. if (data.hasBias) {
  577. initializer(type, 'bias', [ data.outputChannels ], data.bias);
  578. }
  579. return { 'weights': true, 'bias': data.hasBias };
  580. }
  581. case 'innerProduct':
  582. initializer(type, 'weights', [ data.outputChannels, data.inputChannels ], data.weights);
  583. if (data.hasBias) {
  584. initializer(type, 'bias', [ data.outputChannels ], data.bias);
  585. }
  586. return { 'weights': true, 'bias': data.hasBias };
  587. case 'batchnorm':
  588. initializer(type, 'gamma', [ data.channels ], data.gamma);
  589. initializer(type, 'beta', [ data.channels ], data.beta);
  590. if (data.mean) {
  591. initializer(type, 'mean', [ data.channels ], data.mean);
  592. }
  593. if (data.variance) {
  594. initializer(type, 'variance', [ data.channels ], data.variance);
  595. }
  596. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  597. case 'embedding':
  598. initializer(type, 'weights', [ data.inputDim, data.outputChannels ], data.weights);
  599. return { 'weights': true };
  600. case 'loadConstant':
  601. case 'loadConstantND':
  602. initializer(type, 'data', data.shape, data.data);
  603. return { 'data': true };
  604. case 'scale':
  605. initializer(type, 'scale', data.shapeScale, data.scale);
  606. if (data.hasBias) {
  607. initializer(type, 'bias', data.shapeBias, data.bias);
  608. }
  609. return { 'scale': true, 'bias': data.hasBias };
  610. case 'bias':
  611. initializer(type, 'bias', data.shape, data.bias);
  612. return { 'bias': true };
  613. case 'simpleRecurrent':
  614. initializer(type, 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix);
  615. initializer(type, 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix);
  616. if (data.hasBiasVectors) {
  617. initializer(type, 'bias', [ data.outputVectorSize ], data.biasVector);
  618. }
  619. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  620. case 'gru': {
  621. const recursionMatrixShape = [ data.outputVectorSize, data.outputVectorSize ];
  622. const weightMatrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  623. const biasVectorShape = [ data.outputVectorSize ];
  624. initializer(type, 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  625. initializer(type, 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  626. initializer(type, 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  627. initializer(type, 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  628. initializer(type, 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  629. initializer(type, 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  630. if (data.hasBiasVectors) {
  631. initializer(type, 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  632. initializer(type, 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  633. initializer(type, 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  634. }
  635. return {
  636. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  637. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  638. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  639. };
  640. }
  641. case 'uniDirectionalLSTM':
  642. case 'biDirectionalLSTM': {
  643. const count = (type == 'uniDirectionalLSTM') ? 1 : 2;
  644. const h = data.outputVectorSize;
  645. const x = data.inputVectorSize;
  646. for (let i = 0; i < count; i++) {
  647. const weights = count == 1 ? data.weightParams : data.weightParams[i];
  648. const suffix = (i == 0) ? '' : '_rev';
  649. initializer(type, 'inputGateWeightMatrix' + suffix, [h,x], weights.inputGateWeightMatrix);
  650. initializer(type, 'forgetGateWeightMatrix' + suffix, [h,x], weights.forgetGateWeightMatrix);
  651. initializer(type, 'blockInputWeightMatrix' + suffix, [h,x], weights.blockInputWeightMatrix);
  652. initializer(type, 'outputGateWeightMatrix' + suffix, [h,x], weights.outputGateWeightMatrix);
  653. initializer(type, 'inputGateRecursionMatrix' + suffix, [h,h], weights.inputGateRecursionMatrix);
  654. initializer(type, 'forgetGateRecursionMatrix' + suffix, [h,h],weights.forgetGateRecursionMatrix);
  655. initializer(type, 'blockInputRecursionMatrix' + suffix, [h,h], weights.blockInputRecursionMatrix);
  656. initializer(type, 'outputGateRecursionMatrix' + suffix, [h,h], weights.outputGateRecursionMatrix);
  657. if (data.params.hasBiasVectors) {
  658. initializer(type, 'inputGateBiasVector' + suffix, [h], weights.inputGateBiasVector);
  659. initializer(type, 'forgetGateBiasVector' + suffix, [h], weights.forgetGateBiasVector);
  660. initializer(type, 'blockInputBiasVector' + suffix, [h], weights.blockInputBiasVector);
  661. initializer(type, 'outputGateBiasVector' + suffix, [h], weights.outputGateBiasVector);
  662. }
  663. if (data.params.hasPeepholeVectors) {
  664. initializer(type, 'inputGatePeepholeVector' + suffix, [h], weights.inputGatePeepholeVector);
  665. initializer(type, 'forgetGatePeepholeVector' + suffix, [h], weights.forgetGatePeepholeVector);
  666. initializer(type, 'outputGatePeepholeVector' + suffix, [h], weights.outputGatePeepholeVector);
  667. }
  668. }
  669. return { 'weightParams': true };
  670. }
  671. case 'dictVectorizer':
  672. data.stringToIndex = vector(data.stringToIndex);
  673. return {};
  674. case 'wordTagger':
  675. data.modelParameterData = Array.from(data.modelParameterData);
  676. data.stringTags = vector(data.stringTags);
  677. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  678. case 'textClassifier':
  679. data.modelParameterData = Array.from(data.modelParameterData);
  680. data.stringClassLabels = vector(data.stringClassLabels);
  681. return {};
  682. case 'nonMaximumSuppression':
  683. data.stringClassLabels = vector(data.stringClassLabels);
  684. return {};
  685. default:
  686. return {};
  687. }
  688. };
  689. if (data) {
  690. const attributes = obj.attributes;
  691. const map = weights(type, data, initializers);
  692. for (const [name, value] of Object.entries(data)) {
  693. if (!map[name]) {
  694. attributes[name] = value;
  695. }
  696. }
  697. switch (obj.type) {
  698. case 'loop':
  699. attributes.bodyNetwork = this.network(attributes.bodyNetwork);
  700. attributes.conditionNetwork = this.network(attributes.conditionNetwork);
  701. break;
  702. case 'branch':
  703. attributes.ifBranch = this.network(attributes.ifBranch);
  704. attributes.elseBranch = this.network(attributes.elseBranch);
  705. break;
  706. default:
  707. break;
  708. }
  709. }
  710. const metadata = this.metadata.type(type);
  711. for (let i = 0; i < inputs.length;) {
  712. const input = metadata && metadata.inputs && i < metadata.inputs.length ? metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
  713. const count = input.type === 'Tensor[]' ? inputs.length - i : 1;
  714. const values = inputs.slice(i, i + count);
  715. obj.inputs.push({ name: input.name, visible: true, value: values });
  716. i += count;
  717. }
  718. obj.inputs.push(...initializers);
  719. for (let i = 0; i < outputs.length;) {
  720. const output = metadata && metadata.outputs && i < metadata.outputs.length ? metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
  721. const count = output.type === 'Tensor[]' ? outputs.length - i : 1;
  722. const args = outputs.slice(i, i + count);
  723. obj.outputs.push({ name: output.name, visible: true, value: args });
  724. i += count;
  725. }
  726. this.nodes.push(obj);
  727. return obj;
  728. }
  729. model(model, group, description) {
  730. this.groups = this.groups | (group.length > 0 ? true : false);
  731. const shortDescription = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  732. switch (model.Type) {
  733. case 'neuralNetworkClassifier': {
  734. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  735. for (const layer of neuralNetworkClassifier.layers) {
  736. const type = layer.layer;
  737. this.node(group, type, layer.name, group === '' ? '' : shortDescription, layer[type], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  738. }
  739. this.updateClassifierOutput(group, neuralNetworkClassifier, description);
  740. this.updatePreprocessing(group, neuralNetworkClassifier.preprocessing, description);
  741. return 'Neural Network Classifier';
  742. }
  743. case 'neuralNetwork': {
  744. const neuralNetwork = model.neuralNetwork;
  745. for (const layer of neuralNetwork.layers) {
  746. this.node(group, layer.layer, layer.name, group === '' ? '' : shortDescription, layer[layer.layer], layer.input, layer.output, layer.inputTensor, layer.outputTensor);
  747. }
  748. this.updatePreprocessing(group, neuralNetwork.preprocessing, description);
  749. return 'Neural Network';
  750. }
  751. case 'neuralNetworkRegressor': {
  752. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  753. for (const layer of neuralNetworkRegressor.layers) {
  754. this.node(group, layer.layer, layer.name, shortDescription, layer[layer.layer], layer.input, layer.output);
  755. }
  756. this.updatePreprocessing(group, neuralNetworkRegressor, description);
  757. return 'Neural Network Regressor';
  758. }
  759. case 'pipeline': {
  760. for (let i = 0; i < model.pipeline.models.length; i++) {
  761. this.model(model.pipeline.models[i], (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']', description);
  762. }
  763. return 'Pipeline';
  764. }
  765. case 'pipelineClassifier': {
  766. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  767. this.model(model.pipelineClassifier.pipeline.models[i], (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']', description);
  768. }
  769. return 'Pipeline Classifier';
  770. }
  771. case 'pipelineRegressor': {
  772. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  773. this.model(model.pipelineRegressor.pipeline.models[i], (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']', description);
  774. }
  775. return 'Pipeline Regressor';
  776. }
  777. case 'glmClassifier': {
  778. this.node(group, 'glmClassifier', null, shortDescription,
  779. {
  780. classEncoding: model.glmClassifier.classEncoding,
  781. offset: model.glmClassifier.offset,
  782. weights: model.glmClassifier.weights
  783. },
  784. [ model.description.input[0].name ],
  785. [ model.description.output[0].name ]);
  786. this.updateClassifierOutput(group, model.glmClassifier, description);
  787. return 'Generalized Linear Classifier';
  788. }
  789. case 'glmRegressor': {
  790. this.node(group, 'glmRegressor', null, shortDescription,
  791. model.glmRegressor,
  792. [ model.description.input[0].name ],
  793. [ model.description.output[0].name ]);
  794. return 'Generalized Linear Regressor';
  795. }
  796. case 'treeEnsembleClassifier': {
  797. this.node(group, 'treeEnsembleClassifier', null, shortDescription,
  798. model.treeEnsembleClassifier.treeEnsemble,
  799. [ model.description.input[0].name ],
  800. [ model.description.output[0].name ]);
  801. this.updateClassifierOutput(group, model.treeEnsembleClassifier, description);
  802. return 'Tree Ensemble Classifier';
  803. }
  804. case 'treeEnsembleRegressor': {
  805. this.node(group, 'treeEnsembleRegressor', null, shortDescription,
  806. model.treeEnsembleRegressor.treeEnsemble,
  807. [ model.description.input[0].name ],
  808. [ model.description.output[0].name ]);
  809. return 'Tree Ensemble Regressor';
  810. }
  811. case 'supportVectorClassifier': {
  812. this.node(group, 'supportVectorClassifier', null, shortDescription,
  813. {
  814. coefficients: model.supportVectorClassifier.coefficients,
  815. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  816. kernel: model.supportVectorClassifier.kernel,
  817. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  818. probA: model.supportVectorClassifier.probA,
  819. probB: model.supportVectorClassifier.probB,
  820. rho: model.supportVectorClassifier.rho,
  821. supportVectors: model.supportVectorClassifier.supportVectors
  822. },
  823. [ model.description.input[0].name ],
  824. [ model.description.output[0].name ]);
  825. this.updateClassifierOutput(group, model.supportVectorClassifier, description);
  826. return 'Support Vector Classifier';
  827. }
  828. case 'supportVectorRegressor': {
  829. this.node(group, 'supportVectorRegressor', null, shortDescription,
  830. {
  831. coefficients: model.supportVectorRegressor.coefficients,
  832. kernel: model.supportVectorRegressor.kernel,
  833. rho: model.supportVectorRegressor.rho,
  834. supportVectors: model.supportVectorRegressor.supportVectors
  835. },
  836. [ model.description.input[0].name ],
  837. [ model.description.output[0].name ]);
  838. return 'Support Vector Regressor';
  839. }
  840. case 'oneHotEncoder': {
  841. const categoryType = model.oneHotEncoder.CategoryType;
  842. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  843. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  844. this.node(group, 'oneHotEncoder', null, shortDescription,
  845. oneHotEncoderParams,
  846. [ model.description.input[0].name ],
  847. [ model.description.output[0].name ]);
  848. return 'One Hot Encoder';
  849. }
  850. case 'imputer': {
  851. const imputedValue = model.imputer.ImputedValue;
  852. const replaceValue = model.imputer.ReplaceValue;
  853. const imputerParams = {};
  854. imputerParams[imputedValue] = model.imputer[imputedValue];
  855. imputerParams[replaceValue] = model.imputer[replaceValue];
  856. this.node(group, 'oneHotEncoder', null, shortDescription,
  857. imputerParams,
  858. [ model.description.input[0].name ],
  859. [ model.description.output[0].name ]);
  860. return 'Imputer';
  861. }
  862. case 'featureVectorizer': {
  863. this.node(group, 'featureVectorizer', null, shortDescription,
  864. model.featureVectorizer,
  865. model.description.input.map((item) => item.name),
  866. [ model.description.output[0].name ]);
  867. return 'Feature Vectorizer';
  868. }
  869. case 'dictVectorizer': {
  870. this.node(group, 'dictVectorizer', null, shortDescription,
  871. model.dictVectorizer,
  872. [ model.description.input[0].name ],
  873. [ model.description.output[0].name ]);
  874. return 'Dictionary Vectorizer';
  875. }
  876. case 'scaler': {
  877. this.node(group, 'scaler', null, shortDescription,
  878. model.scaler,
  879. [ model.description.input[0].name ],
  880. [ model.description.output[0].name ]);
  881. return 'Scaler';
  882. }
  883. case 'categoricalMapping': {
  884. this.node(group, 'categoricalMapping', null, shortDescription,
  885. model.categoricalMapping,
  886. [ model.description.input[0].name ],
  887. [ model.description.output[0].name ]);
  888. return 'Categorical Mapping';
  889. }
  890. case 'normalizer': {
  891. this.node(group, 'normalizer', null, shortDescription,
  892. model.normalizer,
  893. [ model.description.input[0].name ],
  894. [ model.description.output[0].name ]);
  895. return 'Normalizer';
  896. }
  897. case 'arrayFeatureExtractor': {
  898. this.node(group, 'arrayFeatureExtractor', null, shortDescription,
  899. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  900. [ model.description.input[0].name ],
  901. [ model.description.output[0].name ]);
  902. return 'Array Feature Extractor';
  903. }
  904. case 'nonMaximumSuppression': {
  905. const nonMaximumSuppressionParams = {
  906. pickTop: model.nonMaximumSuppression.pickTop,
  907. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  908. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  909. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  910. };
  911. this.node(group, 'nonMaximumSuppression', null, shortDescription,
  912. nonMaximumSuppressionParams,
  913. [
  914. model.nonMaximumSuppression.confidenceInputFeatureName,
  915. model.nonMaximumSuppression.coordinatesInputFeatureName,
  916. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  917. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  918. ],
  919. [
  920. model.nonMaximumSuppression.confidenceOutputFeatureName,
  921. model.nonMaximumSuppression.coordinatesOutputFeatureName
  922. ]);
  923. return 'Non Maximum Suppression';
  924. }
  925. case 'wordTagger': {
  926. this.node(group, 'wordTagger', null, shortDescription,
  927. model.wordTagger,
  928. [ model.description.input[0].name ],
  929. [
  930. model.wordTagger.tokensOutputFeatureName,
  931. model.wordTagger.tokenTagsOutputFeatureName,
  932. model.wordTagger.tokenLocationsOutputFeatureName,
  933. model.wordTagger.tokenLengthsOutputFeatureName
  934. ]);
  935. return 'Word Tagger';
  936. }
  937. case 'textClassifier': {
  938. this.node(group, 'textClassifier', null, shortDescription,
  939. model.textClassifier,
  940. [ model.description.input[0].name ],
  941. [ model.description.output[0].name ]);
  942. return 'Text Classifier';
  943. }
  944. case 'visionFeaturePrint': {
  945. const visionFeaturePrintParams = {
  946. scene: model.visionFeaturePrint.scene
  947. };
  948. this.node(group, 'visionFeaturePrint', null, shortDescription,
  949. visionFeaturePrintParams,
  950. [ model.description.input[0].name ],
  951. [ model.description.output[0].name ]);
  952. return 'Vision Feature Print';
  953. }
  954. case 'soundAnalysisPreprocessing': {
  955. this.node(group, 'soundAnalysisPreprocessing', null, shortDescription,
  956. model.soundAnalysisPreprocessing,
  957. [ model.description.input[0].name ],
  958. [ model.description.output[0].name ]);
  959. return 'Sound Analysis Preprocessing';
  960. }
  961. case 'kNearestNeighborsClassifier': {
  962. this.node(group, 'kNearestNeighborsClassifier', null, shortDescription,
  963. model.kNearestNeighborsClassifier,
  964. [ model.description.input[0].name ],
  965. [ model.description.output[0].name ]);
  966. this.updateClassifierOutput(group, model.kNearestNeighborsClassifier, description);
  967. return 'Nearest Neighbors Classifier';
  968. }
  969. case 'itemSimilarityRecommender': {
  970. this.node(group, 'itemSimilarityRecommender', null, shortDescription,
  971. {
  972. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  973. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  974. },
  975. model.description.input.map((feature) => feature.name),
  976. model.description.output.map((feature) => feature.name));
  977. return 'Item Similarity Recommender';
  978. }
  979. case 'audioFeaturePrint': {
  980. this.node(group, 'audioFeaturePrint', null, shortDescription,
  981. model.audioFeaturePrint,
  982. [ model.description.input[0].name ],
  983. [ model.description.output[0].name ]);
  984. return 'Audio Feature Print';
  985. }
  986. case 'linkedModel': {
  987. this.node(group, 'linkedModel', null, shortDescription,
  988. model.linkedModel.linkedModelFile,
  989. [ model.description.input[0].name ],
  990. [ model.description.output[0].name ]);
  991. return 'Linked Model';
  992. }
  993. case 'customModel': {
  994. this.node(group, 'customModel', null, shortDescription,
  995. { className: model.customModel.className, parameters: model.customModel.parameters },
  996. [ model.description.input[0].name ],
  997. [ model.description.output[0].name ]);
  998. return 'customModel';
  999. }
  1000. case 'mlProgram': {
  1001. return this.program(model.mlProgram, group);
  1002. }
  1003. default: {
  1004. throw new coreml.Error("Unsupported model type '" + JSON.stringify(Object.keys(model)) + "'.");
  1005. }
  1006. }
  1007. }
  1008. updateClassifierOutput(group, classifier, description) {
  1009. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  1010. if (!labelProbabilityLayerName && this.nodes.length > 0) {
  1011. const node = this.nodes.slice(-1).pop();
  1012. if (node && node.outputs.length == 1 && node.outputs[0].value.length == 1) {
  1013. labelProbabilityLayerName = node.outputs[0].value[0].name;
  1014. }
  1015. }
  1016. let predictedFeatureName = description.predictedFeatureName;
  1017. let predictedProbabilitiesName = description.predictedProbabilitiesName;
  1018. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  1019. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  1020. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  1021. const labelProbabilityInput = labelProbabilityLayerName + ':labelProbabilityLayerName';
  1022. const values = new Set();
  1023. for (const node of this.nodes) {
  1024. for (const output of node.outputs) {
  1025. for (const value of output.value) {
  1026. if (value.name === labelProbabilityLayerName) {
  1027. value.name = labelProbabilityInput;
  1028. values.add(value);
  1029. }
  1030. }
  1031. }
  1032. }
  1033. this.values.set(labelProbabilityInput, this.values.get(labelProbabilityLayerName));
  1034. this.values.delete(labelProbabilityLayerName);
  1035. const type = classifier.ClassLabels;
  1036. const node = {
  1037. // group: this._group,
  1038. type: type,
  1039. name: null,
  1040. description: '',
  1041. attributes: classifier[type] || {}
  1042. };
  1043. node.inputs = [
  1044. { name: 'input', visible: true, value: Array.from(values) }
  1045. ];
  1046. node.outputs = [
  1047. { name: 'probabilities', visible: true, value: [ this.output(predictedProbabilitiesName) ] },
  1048. { name: 'feature', visible: true, value: [ this.output(predictedFeatureName) ] }
  1049. ];
  1050. this.nodes.push(node);
  1051. }
  1052. }
  1053. updatePreprocessing(group, preprocessings, description) {
  1054. if (preprocessings && preprocessings.length > 0) {
  1055. const preprocessingInput = description.input[0].name;
  1056. const inputNodes = [];
  1057. for (const node of this.nodes) {
  1058. if (node.inputs.some((input) => Array.isArray(input.value) && input.value.some((arg) => arg.name === preprocessingInput))) {
  1059. inputNodes.push(node);
  1060. }
  1061. }
  1062. let currentOutput = preprocessingInput;
  1063. let preprocessorOutput = null;
  1064. let preprocessorIndex = 0;
  1065. for (const preprocessing of preprocessings) {
  1066. const input = preprocessing.featureName ? preprocessing.featureName : currentOutput;
  1067. currentOutput = preprocessingInput + ':' + preprocessorIndex.toString();
  1068. const preprocessor = preprocessing.preprocessor;
  1069. const node = this.node(group, preprocessor, null, '', preprocessing[preprocessor], [ input ], [ currentOutput ]);
  1070. /* eslint-disable prefer-destructuring */
  1071. preprocessorOutput = node.outputs[0].value[0];
  1072. /* eslint-enable prefer-destructuring */
  1073. preprocessorIndex++;
  1074. }
  1075. for (const node of inputNodes) {
  1076. for (const input of node.inputs) {
  1077. if (Array.isArray(input.value)) {
  1078. for (let i = 0; i < input.value.length; i++) {
  1079. if (input.value[i].name === preprocessingInput) {
  1080. input.value[i] = preprocessorOutput;
  1081. }
  1082. }
  1083. }
  1084. }
  1085. }
  1086. }
  1087. }
  1088. program(program, group) {
  1089. // TODO: need to handle functions other than main?
  1090. const main = program.functions.main;
  1091. // TODO: need to handle more than one block specialization?
  1092. const block_specializations = main.block_specializations;
  1093. const key = Object.keys(block_specializations).filter((key) => key.startsWith('CoreML')).shift();
  1094. const block = block_specializations[key];
  1095. const convertValue = (value) => {
  1096. switch (value.value) {
  1097. case 'immediateValue': {
  1098. const tensor = value.immediateValue.tensor;
  1099. const type = coreml.Utility.valueType(value.type);
  1100. let values = null;
  1101. switch (tensor.value) {
  1102. case 'ints':
  1103. values = tensor.ints.values;
  1104. break;
  1105. case 'strings':
  1106. values = tensor.strings.values;
  1107. break;
  1108. case 'bools':
  1109. values = tensor.bools.values;
  1110. break;
  1111. case 'floats':
  1112. values = tensor.floats.values;
  1113. break;
  1114. case 'bytes':
  1115. values = tensor.bytes.values;
  1116. break;
  1117. default:
  1118. throw new coreml.Error("Unsupported tensor value '" + tensor.value + "'.");
  1119. }
  1120. if (type.shape.dimensions.length === 0) {
  1121. [values] = values;
  1122. }
  1123. return values;
  1124. }
  1125. case 'blobFileValue': {
  1126. const type = coreml.Utility.valueType(value.type);
  1127. const blob = value.blobFileValue;
  1128. const offset = blob.offset.toNumber();
  1129. const file = blob.fileName;
  1130. let data = null;
  1131. const stream = this.weights.get(file);
  1132. if (stream) {
  1133. stream.seek(offset);
  1134. const buffer = stream.read(32);
  1135. const reader = new base.BinaryReader(buffer);
  1136. const signature = reader.uint32();
  1137. if (signature == 0xdeadbeef) {
  1138. reader.uint32(); // dataType
  1139. const size = reader.uint64();
  1140. stream.seek(reader.uint64());
  1141. const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
  1142. switch (type.dataType) {
  1143. case 'float32': {
  1144. const buffer = stream.read(size);
  1145. data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
  1146. break;
  1147. }
  1148. case 'float16':
  1149. case 'int8':
  1150. case 'uint8': {
  1151. data = stream.read(size);
  1152. break;
  1153. }
  1154. default:
  1155. throw new coreml.Error("Unsupported blob data type '" + type.dataType + "'.");
  1156. }
  1157. }
  1158. }
  1159. return new coreml.Tensor(type, data, null, 'Blob');
  1160. }
  1161. default: {
  1162. throw new coreml.Error("Unsupported value '" + value.value + "'.");
  1163. }
  1164. }
  1165. };
  1166. const operations = block.operations.map((op) => {
  1167. const operation = {
  1168. type: op.type,
  1169. attributes: {}
  1170. };
  1171. for (const [key, value] of Object.entries(op.attributes)) {
  1172. operation.attributes[key] = convertValue(value);
  1173. }
  1174. operation.inputs = Object.entries(op.inputs).map(([name, input]) => {
  1175. const args = input.arguments.map((argument) => {
  1176. if (argument.name) {
  1177. const value = this.input(argument.name);
  1178. value.to.push(operation);
  1179. return value;
  1180. }
  1181. return { value: argument.value };
  1182. });
  1183. return { name: name, value: args };
  1184. });
  1185. operation.outputs = op.outputs.map((output) => {
  1186. const value = this.input(output.name);
  1187. value.type = coreml.Utility.valueType(output.type);
  1188. value.from.push(operation);
  1189. return { name: 'output', value: [ value ] };
  1190. });
  1191. return operation;
  1192. });
  1193. for (const op of operations) {
  1194. if (op.type === 'const' && op.inputs.length === 0 &&
  1195. op.outputs.length === 1 && op.outputs[0].value.length === 1) {
  1196. /* eslint-disable prefer-destructuring */
  1197. const value = op.outputs[0].value[0];
  1198. /* eslint-enable prefer-destructuring */
  1199. if (op.attributes && op.attributes.val) {
  1200. const type = value.type;
  1201. const data = op.attributes.val;
  1202. if (data instanceof Uint8Array && data.length === 2 &&
  1203. type.dataType === 'float16' && type.shape.dimensions.length === 0) {
  1204. const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  1205. value.value = view.getFloat16(0, true);
  1206. } else {
  1207. value.value = data;
  1208. }
  1209. value.const = true;
  1210. op.delete = true;
  1211. }
  1212. }
  1213. }
  1214. for (const op of operations) {
  1215. for (const input of op.inputs) {
  1216. if (input.value.length > 1 && input.value.some((argument) => argument.const)) {
  1217. if (!input.value.every((argument) => argument.value instanceof coreml.Tensor)) {
  1218. for (const value of input.value) {
  1219. for (const from of value.from) {
  1220. from.delete = false;
  1221. }
  1222. delete value.value;
  1223. }
  1224. }
  1225. }
  1226. }
  1227. }
  1228. for (const op of operations.filter((op) => !op.delete)) {
  1229. op.inputs = op.inputs.filter((input) => {
  1230. if (input.value.every((value) => value.value === undefined || value.value instanceof coreml.Tensor)) {
  1231. return true;
  1232. }
  1233. op.attributes[input.name] = input.value.length === 1 ?
  1234. input.value[0].value :
  1235. input.value.map((argument) => argument.value[0]);
  1236. return false;
  1237. });
  1238. }
  1239. const mapValue = (name, value) => {
  1240. if (value.value instanceof coreml.Tensor) {
  1241. value.initializer = value.value;
  1242. delete value.value;
  1243. }
  1244. if (!this.values.has(name)) {
  1245. this.values.set(name, value);
  1246. } else if ((value.type && !value.type.equals(this.values.get(name).type)) ||
  1247. (value.initializer && value.initializer !== this.values.get(name).initializer)) {
  1248. throw new coreml.Error("Duplicate value '" + name + "'.");
  1249. }
  1250. return this.values.get(name);
  1251. };
  1252. for (const op of operations.filter((op) => !op.delete)) {
  1253. for (const argument of op.inputs) {
  1254. for (const value of argument.value) {
  1255. mapValue(value.name, value);
  1256. }
  1257. }
  1258. for (const argument of op.outputs) {
  1259. for (const value of argument.value) {
  1260. mapValue(value.name, value);
  1261. }
  1262. }
  1263. }
  1264. for (const op of operations.filter((op) => !op.delete)) {
  1265. op.group = group;
  1266. op.type = 'program:' + op.type;
  1267. const metadata = this.metadata.type(op.type);
  1268. if (metadata && Array.isArray(metadata.inputs)) {
  1269. const map = new Map(metadata.inputs.map((input, index) => [ input.name, index + 1 ]));
  1270. op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
  1271. }
  1272. this.nodes.push(op);
  1273. }
  1274. return 'ML Program';
  1275. }
  1276. };
  /**
   * Static helpers that translate Core ML and MIL protobuf type descriptions into the loader's
   * TensorType / MapType / SequenceType / ImageType / ListType / OptionalType objects.
   */
  coreml.Utility = class {

    /**
     * Looks up the symbolic key of an enum value.
     * @param {string} name - Dotted path of the enum below `coreml.proto`.
     * @param {*} value - Raw enum value.
     * @returns {*} The matching enum key, or `value` unchanged when the path does not resolve or
     *     no key has that value.
     */
    static enum(name, value) {
      let container = coreml.proto;
      for (const part of name.split('.')) {
        if (!container) {
          break;
        }
        container = container[part];
      }
      if (!container) {
        return value;
      }
      // Reverse (value -> key) tables are built once per enum path and cached on the class.
      coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
      const cache = coreml.Utility._enumKeyMap;
      if (!cache.has(name)) {
        const table = new Map();
        for (const [key, code] of Object.entries(container)) {
          table.set(code, key);
        }
        cache.set(name, table);
      }
      const reverse = cache.get(name);
      return reverse.has(value) ? reverse.get(value) : value;
    }

    /**
     * Converts a Core ML `FeatureType` into a loader type object.
     * @param {*} type - Feature type message; `type.Type` names the populated field.
     * @returns {*} The converted type (wrapped in OptionalType when `isOptional`), or '?' when
     *     `type` is falsy.
     * @throws {coreml.Error} For unsupported feature kinds or array data types.
     */
    static featureType(type) {
      if (!type) {
        return '?';
      }
      const kind = type.Type;
      let result = null;
      if (kind === 'multiArrayType') {
        const array = type.multiArrayType;
        const shape = array.shape && array.shape.length > 0 ?
          new coreml.TensorShape(array.shape.map((dim) => dim.toNumber())) :
          new coreml.TensorShape([]);
        const ArrayDataType = coreml.proto.ArrayFeatureType.ArrayDataType;
        const dataTypes = new Map([
          [ ArrayDataType.INVALID_ARRAY_DATA_TYPE, '?' ],
          [ ArrayDataType.FLOAT16, 'float16' ],
          [ ArrayDataType.FLOAT32, 'float32' ],
          [ ArrayDataType.DOUBLE, 'float64' ],
          [ ArrayDataType.INT32, 'int32' ]
        ]);
        if (!dataTypes.has(array.dataType)) {
          throw new coreml.Error("Unsupported array data type '" + type.multiArrayType.dataType + "'.");
        }
        result = new coreml.TensorType(dataTypes.get(array.dataType), shape);
      } else if (kind === 'stringType') {
        result = new coreml.TensorType('string');
      } else if (kind === 'doubleType') {
        result = new coreml.TensorType('float64');
      } else if (kind === 'int64Type') {
        result = new coreml.TensorType('int64');
      } else if (kind === 'dictionaryType') {
        // KeyType is a field name such as '<x>KeyType'; the suffix is stripped to get the key type.
        const keyType = type.dictionaryType.KeyType.replace('KeyType', '');
        result = new coreml.MapType(keyType, 'float64');
      } else if (kind === 'sequenceType') {
        // The sequence message carries its own element `Type`, so it is converted recursively.
        result = new coreml.SequenceType(coreml.Utility.featureType(type.sequenceType));
      } else if (kind === 'imageType') {
        const image = type.imageType;
        result = new coreml.ImageType(image.colorSpace, image.width, image.height);
      } else {
        throw new coreml.Error("Unsupported feature type '" + type.Type + "'.");
      }
      return type.isOptional ? new coreml.OptionalType(result) : result;
    }

    /**
     * Converts a MIL `TensorType` into a loader TensorType; non-constant dimensions become '?'.
     * @param {*} type - MIL tensor type message (`dataType`, `dimensions`).
     * @returns {coreml.TensorType} The converted tensor type.
     * @throws {coreml.Error} When the data type has no known name.
     */
    static tensorType(type) {
      if (!coreml.Utility._dataTypes) {
        // Lower-cased enum names keyed by value; 0 is dropped and 1 is renamed 'boolean'.
        const names = new Map();
        for (const [key, code] of Object.entries(coreml.proto.MILSpec.DataType)) {
          names.set(code, key.toLowerCase());
        }
        names.delete(0);
        names.set(1, 'boolean');
        coreml.Utility._dataTypes = names;
      }
      const dimensions = type.dimensions.map((dim) => (dim.constant ? dim.constant.size : '?'));
      const dataType = coreml.Utility._dataTypes.get(type.dataType);
      if (!dataType) {
        throw new coreml.Error("Unsupported data type '" + type.dataType + "'.");
      }
      return new coreml.TensorType(dataType, new coreml.TensorShape(dimensions));
    }

    /**
     * Converts a MIL `ValueType` (tensor, list or dictionary) into a loader type object.
     * @param {*} type - MIL value type message; `type.type` names the populated field.
     * @returns {*} The converted type.
     * @throws {coreml.Error} For unsupported value kinds.
     */
    static valueType(type) {
      const kind = type.type;
      if (kind === 'tensorType') {
        return coreml.Utility.tensorType(type.tensorType);
      }
      if (kind === 'listType') {
        return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
      }
      if (kind === 'dictionaryType') {
        const dictionary = type.dictionaryType;
        return new coreml.MapType(coreml.Utility.valueType(dictionary.keyType), coreml.Utility.valueType(dictionary.valueType));
      }
      throw new coreml.Error("Unsupported value type '" + type.type + "'.");
    }
  };
  /**
   * Error type thrown by the Core ML loader for unsupported or malformed model content.
   * Its fixed `name` identifies the failure as a Core ML loading error.
   */
  coreml.Error = class extends Error {
    /**
     * @param {string} text - Description of the specific problem.
     */
    constructor(text) {
      super(text);
      this.name = 'Error loading Core ML model.';
    }
  };
  // Module entry point: re-export `coreml.ModelFactory` as the named export `ModelFactory`.
  export const { ModelFactory } = coreml;