2
0

coreml.js 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685
  1. /* jshint esversion: 6 */
  2. var coreml = coreml || {};
  3. var json = json || require('./json');
  4. var protobuf = protobuf || require('./protobuf');
  /**
   * Recognizes Core ML content (protobuf models, package manifests, package
   * metadata and weight blobs) and opens it as a coreml.Model.
   */
  coreml.ModelFactory = class {

      /**
       * Classifies the content behind `context`.
       *
       * @param {object} context - Host context exposing `stream`, `identifier`,
       *     `tags(type)` and `open(type)`.
       * @returns {string|null|undefined} One of 'coreml.pb', 'coreml.manifest',
       *     'coreml.metadata', 'coreml.featuredescriptions', 'coreml.weights';
       *     null for a '.pb' file whose fields are not recognized; undefined
       *     when nothing matches.
       */
      match(context) {
          const stream = context.stream;
          const identifier = context.identifier.toLowerCase();
          const extension = identifier.split('.').pop().toLowerCase();
          const tags = context.tags('pb');
          if (tags.get(1) === 0 && tags.get(2) === 2) {
              if (extension === 'pb') {
                  // '.pb' is shared with other formats, so additionally require
                  // at least one field number from the accepted ranges
                  // (presumably the model-type fields of the Core ML Model
                  // message; confirm against the specification).
                  const fields = context.tags('pb+');
                  const keys = Object.keys(fields).map((key) => parseInt(key, 10));
                  const isModelField = (key) =>
                      (key >= 200 && key < 220) ||
                      (key >= 300 && key < 320) ||
                      (key >= 400 && key < 420) ||
                      (key >= 500 && key < 520) ||
                      (key >= 550 && key < 560) ||
                      (key >= 600 && key < 620) ||
                      (key === 900) ||
                      (key >= 2000 && key < 2010) ||
                      (key === 3000);
                  if (!keys.some((key) => isModelField(key))) {
                      return null;
                  }
              }
              return 'coreml.pb';
          }
          switch (identifier) {
              case 'manifest.json': {
                  const obj = context.open('json');
                  if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                      const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
                      // Require exactly one '.mlmodel' entry. The length check must
                      // apply to the filtered array, not to the boolean returned by
                      // endsWith(), otherwise every manifest would match.
                      const models = entries.filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel'));
                      if (models.length === 1) {
                          return 'coreml.manifest';
                      }
                  }
                  break;
              }
              case 'metadata.json': {
                  const obj = context.open('json');
                  if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                      return 'coreml.metadata';
                  }
                  break;
              }
              case 'featuredescriptions.json': {
                  const obj = context.open('json');
                  if (obj && (obj.Inputs || obj.Outputs)) {
                      return 'coreml.featuredescriptions';
                  }
                  break;
              }
          }
          if (extension === 'bin' && stream.length > 16) {
              // Scan the first bytes for the little-endian 0xdeadbeef signature.
              // The bound is inclusive so the final 4-byte window is tested too.
              const buffer = stream.peek(Math.min(256, stream.length));
              for (let i = 0; i <= buffer.length - 4; i++) {
                  const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer[i + 3] << 24) >>> 0;
                  if (signature === 0xdeadbeef) {
                      return 'coreml.weights';
                  }
              }
          }
          return undefined;
      }

      /**
       * Opens the content previously classified by match().
       *
       * @param {object} context - Host context exposing `require`, `request`,
       *     `open`, `stream` and `identifier`.
       * @param {string} match - Value returned by match().
       * @returns {Promise<coreml.Model>} Resolves to the loaded model.
       * @throws {coreml.Error} (as a rejection) when the protobuf cannot be
       *     decoded or `match` is not a known format.
       */
      open(context, match) {
          return context.require('./coreml-proto').then(() => {
              return coreml.Metadata.open(context).then((metadata) => {
                  // Decodes a model stream and, when it references external blob
                  // files, requests them relative to the model's folder.
                  const openModel = (stream, context, path, format) => {
                      let model = null;
                      try {
                          coreml.proto = protobuf.get('coreml').CoreML.Specification;
                          const reader = protobuf.BinaryReader.open(stream);
                          model = coreml.proto.Model.decode(reader);
                      }
                      catch (error) {
                          const message = error && error.message ? error.message : error.toString();
                          throw new coreml.Error('File format is not coreml.Model (' + message.replace(/\.$/, '') + ').');
                      }
                      const weightPaths = new Set();
                      // Collects every blob file name referenced by operation attributes.
                      const walkProgram = (program) => {
                          for (const func of Object.values(program.functions)) {
                              for (const block of Object.values(func.block_specializations)) {
                                  for (const operation of block.operations) {
                                      for (const value of Object.values(operation.attributes)) {
                                          if (value.blobFileValue && value.blobFileValue.fileName) {
                                              weightPaths.add(value.blobFileValue.fileName);
                                          }
                                      }
                                  }
                              }
                          }
                      };
                      // Visits the model and any models nested in pipelines.
                      const walkModel = (model) => {
                          if (model.mlProgram) {
                              walkProgram(model.mlProgram);
                          }
                          if (model.pipeline && model.pipeline.models) {
                              for (const node of model.pipeline.models) {
                                  walkModel(node);
                              }
                          }
                          if (model.pipelineClassifier && model.pipelineClassifier.pipeline && model.pipelineClassifier.pipeline.models) {
                              for (const node of model.pipelineClassifier.pipeline.models) {
                                  walkModel(node);
                              }
                          }
                          if (model.pipelineRegressor && model.pipelineRegressor.pipeline && model.pipelineRegressor.pipeline.models) {
                              for (const node of model.pipelineRegressor.pipeline.models) {
                                  walkModel(node);
                              }
                          }
                      };
                      walkModel(model);
                      if (weightPaths.size > 0) {
                          const items = path.split('/');
                          items.pop();
                          const folder = items.join('/');
                          const keys = Array.from(weightPaths);
                          // '@model_path' is a placeholder for the folder containing the model file.
                          const paths = keys.map((path) => {
                              const items = path.split('/');
                              if (items[0] === '@model_path') {
                                  items[0] = folder;
                              }
                              return items.join('/');
                          });
                          const promises = paths.map((path) => context.request(path, null));
                          return Promise.all(promises).then((streams) => {
                              const weights = new Map();
                              for (let i = 0; i < keys.length; i++) {
                                  weights.set(keys[i], streams[i]);
                              }
                              return new coreml.Model(metadata, format, model, weights);
                          }).catch((/* err */) => {
                              // Best effort: fall back to the model without weight data.
                              return new coreml.Model(metadata, format, model, new Map());
                          });
                      }
                      return new coreml.Model(metadata, format, model, new Map());
                  };
                  // Opens the '.mlmodel' entry listed by a parsed package manifest.
                  const openManifest = (obj, context, path) => {
                      const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
                      const entry = entries.filter((entry) => entry.path.toLowerCase().endsWith('.mlmodel'))[0];
                      const file = path + 'Data/' + entry.path;
                      return context.request(file, null).then((stream) => {
                          return openModel(stream, context, file, 'Core ML Package');
                      });
                  };
                  // Requests and parses 'Manifest.json' at `path`, then opens it.
                  const openManifestStream = (context, path) => {
                      return context.request(path + 'Manifest.json', null).then((stream) => {
                          const reader = json.TextReader.open(stream);
                          const obj = reader.read();
                          return openManifest(obj, context, path);
                      });
                  };
                  switch (match) {
                      case 'coreml.pb': {
                          return openModel(context.stream, context, context.identifier);
                      }
                      case 'coreml.manifest': {
                          const obj = context.open('json');
                          return openManifest(obj, context, '');
                      }
                      case 'coreml.featuredescriptions':
                      case 'coreml.metadata': {
                          return openManifestStream(context, '../../');
                      }
                      case 'coreml.weights': {
                          return openManifestStream(context, '../../../');
                      }
                      default: {
                          throw new coreml.Error("Unknown Core ML format '" + match + "'.");
                      }
                  }
              });
          });
      }
  };
  /**
   * Top-level model: a format label, descriptive metadata copied from the
   * model description, and a single graph.
   */
  coreml.Model = class {

      /**
       * @param {object} metadata - Operator metadata forwarded to coreml.Graph.
       * @param {string} [format] - Format label; defaults to 'Core ML'.
       * @param {object} model - Decoded model message; `specificationVersion`
       *     is appended to the format label.
       * @param {Map} weights - Blob streams keyed by file name, forwarded to
       *     coreml.Graph.
       */
      constructor(metadata, format, model, weights) {
          this._format = (format || 'Core ML') + ' v' + model.specificationVersion.toString();
          this._graphs = [ new coreml.Graph(metadata, model, weights) ];
          if (model.description && model.description.metadata) {
              const properties = model.description.metadata;
              if (properties.versionString) {
                  this._version = properties.versionString;
              }
              if (properties.author) {
                  this._author = properties.author;
              }
              if (properties.shortDescription) {
                  this._description = properties.shortDescription;
              }
              if (properties.license) {
                  this._license = properties.license;
              }
              // User-defined metadata (properties.userDefined) is not surfaced.
              // The previous empty branch tested the wrong object and was removed.
          }
      }

      /** @returns {string} Format label including the specification version. */
      get format() {
          return this._format;
      }

      /** @returns {string|null} Version string from the model metadata, if any. */
      get version() {
          return this._version || null;
      }

      /** @returns {string|null} Short description from the model metadata, if any. */
      get description() {
          return this._description || null;
      }

      /** @returns {string|null} Author from the model metadata, if any. */
      get author() {
          return this._author || null;
      }

      /** @returns {string|null} License from the model metadata, if any. */
      get license() {
          return this._license || null;
      }

      /** @returns {Array} The graphs of this model (always exactly one). */
      get graphs() {
          return this._graphs;
      }
  };
  225. coreml.Graph = class {
  226. constructor(metadata, model, weights) {
  227. this._metadata = metadata;
  228. this._description = model.description;
  229. this._groups = false;
  230. this._inputs = [];
  231. this._outputs = [];
  232. this._nodes = [];
  233. if (this._description) {
  234. this._inputs = this._description.input.map((input) => {
  235. const argument = new coreml.Argument(input.name, coreml.Utility.featureType(input.type), input.shortDescription, null);
  236. return new coreml.Parameter(input.name, true, [ argument ]);
  237. });
  238. this._outputs = this._description.output.map((output) => {
  239. const argument = new coreml.Argument(output.name, coreml.Utility.featureType(output.type), output.shortDescription, null);
  240. return new coreml.Parameter(output.name, true, [ argument ]);
  241. });
  242. }
  243. this._type = this._loadModel(model, {}, '', weights);
  244. }
  245. get name() {
  246. return '';
  247. }
  248. get type() {
  249. return this._type;
  250. }
  251. get inputs() {
  252. return this._inputs;
  253. }
  254. get outputs() {
  255. return this._outputs;
  256. }
  257. get nodes() {
  258. return this._nodes;
  259. }
  260. get groups() {
  261. return this._groups;
  262. }
  263. _updateOutput(name, newName) {
  264. for (const node of this._nodes) {
  265. for (const output of node.outputs) {
  266. for (const argument of output.arguments) {
  267. if (argument.name === name) {
  268. argument.name = newName;
  269. }
  270. }
  271. }
  272. }
  273. return newName;
  274. }
  275. _updateClassifierOutput(group, classifier) {
  276. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  277. if (!labelProbabilityLayerName && this._nodes.length > 0) {
  278. const node = this._nodes.slice(-1).pop();
  279. if (node && node.outputs.length == 1 && node.outputs[0].arguments.length == 1) {
  280. labelProbabilityLayerName = node.outputs[0].arguments[0].name;
  281. }
  282. }
  283. let predictedFeatureName = this._description.predictedFeatureName;
  284. let predictedProbabilitiesName = this._description.predictedProbabilitiesName;
  285. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  286. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  287. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  288. const labelProbabilityInput = this._updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
  289. const type = classifier.ClassLabels;
  290. const inputs = [
  291. new coreml.Parameter('input', true, [ new coreml.Argument(labelProbabilityInput) ])
  292. ];
  293. const outputs = [
  294. new coreml.Parameter('probabilities', true, [ new coreml.Argument(predictedProbabilitiesName) ]),
  295. new coreml.Parameter('feature', true, [ new coreml.Argument(predictedFeatureName) ])
  296. ];
  297. const node = new coreml.Node(this._metadata, this._group, type, null, '', classifier[type], inputs, outputs);
  298. this._nodes.push(node);
  299. }
  300. }
  301. _updatePreprocessing(scope, group, preprocessing) {
  302. if (preprocessing && preprocessing.length > 0) {
  303. const preprocessingInput = this._description.input[0].name;
  304. const inputNodes = [];
  305. for (const node of this._nodes) {
  306. if (node.inputs.some((input) => input.arguments.some((arg) => arg.name == preprocessingInput))) {
  307. inputNodes.push(node);
  308. }
  309. }
  310. let preprocessorOutput = preprocessingInput;
  311. let preprocessorIndex = 0;
  312. for (const p of preprocessing) {
  313. const input = p.featureName ? p.featureName : preprocessorOutput;
  314. preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString();
  315. this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]);
  316. preprocessorIndex++;
  317. }
  318. for (const node of inputNodes) {
  319. for (const input of node.inputs) {
  320. for (const arg of input.arguments) {
  321. if (arg.name === preprocessingInput) {
  322. arg.name = preprocessorOutput;
  323. }
  324. }
  325. }
  326. }
  327. }
  328. }
  329. _loadModel(model, scope, group, weights) {
  330. this._groups = this._groups | (group.length > 0 ? true : false);
  331. const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  332. switch (model.Type) {
  333. case 'neuralNetworkClassifier': {
  334. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  335. for (const layer of neuralNetworkClassifier.layers) {
  336. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  337. }
  338. this._updateClassifierOutput(group, neuralNetworkClassifier);
  339. this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing);
  340. return 'Neural Network Classifier';
  341. }
  342. case 'neuralNetwork': {
  343. const neuralNetwork = model.neuralNetwork;
  344. for (const layer of neuralNetwork.layers) {
  345. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  346. }
  347. this._updatePreprocessing(scope, group, neuralNetwork.preprocessing);
  348. return 'Neural Network';
  349. }
  350. case 'neuralNetworkRegressor': {
  351. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  352. for (const layer of neuralNetworkRegressor.layers) {
  353. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  354. }
  355. this._updatePreprocessing(scope, group, neuralNetworkRegressor);
  356. return 'Neural Network Regressor';
  357. }
  358. case 'pipeline': {
  359. for (let i = 0; i < model.pipeline.models.length; i++) {
  360. this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']');
  361. }
  362. return 'Pipeline';
  363. }
  364. case 'pipelineClassifier': {
  365. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  366. this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']');
  367. }
  368. return 'Pipeline Classifier';
  369. }
  370. case 'pipelineRegressor': {
  371. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  372. this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']');
  373. }
  374. return 'Pipeline Regressor';
  375. }
  376. case 'glmClassifier': {
  377. this._createNode(scope, group, 'glmClassifier', null, description,
  378. {
  379. classEncoding: model.glmClassifier.classEncoding,
  380. offset: model.glmClassifier.offset,
  381. weights: model.glmClassifier.weights
  382. },
  383. [ model.description.input[0].name ],
  384. [ model.description.predictedProbabilitiesName ]);
  385. this._updateClassifierOutput(group, model.glmClassifier);
  386. return 'Generalized Linear Classifier';
  387. }
  388. case 'glmRegressor': {
  389. this._createNode(scope, group, 'glmRegressor', null, description,
  390. model.glmRegressor,
  391. [ model.description.input[0].name ],
  392. [ model.description.output[0].name ]);
  393. return 'Generalized Linear Regressor';
  394. }
  395. case 'dictVectorizer': {
  396. this._createNode(scope, group, 'dictVectorizer', null, description,
  397. model.dictVectorizer,
  398. [ model.description.input[0].name ],
  399. [ model.description.output[0].name ]);
  400. return 'Dictionary Vectorizer';
  401. }
  402. case 'featureVectorizer': {
  403. this._createNode(scope, group, 'featureVectorizer', null, description,
  404. model.featureVectorizer,
  405. coreml.Graph._formatFeatureDescriptionList(model.description.input),
  406. [ model.description.output[0].name ]);
  407. return 'Feature Vectorizer';
  408. }
  409. case 'treeEnsembleClassifier': {
  410. this._createNode(scope, group, 'treeEnsembleClassifier', null, description,
  411. model.treeEnsembleClassifier.treeEnsemble,
  412. [ model.description.input[0].name ],
  413. [ model.description.output[0].name ]);
  414. this._updateClassifierOutput(group, model.treeEnsembleClassifier);
  415. return 'Tree Ensemble Classifier';
  416. }
  417. case 'treeEnsembleRegressor': {
  418. this._createNode(scope, group, 'treeEnsembleRegressor', null, description,
  419. model.treeEnsembleRegressor.treeEnsemble,
  420. [ model.description.input[0].name ],
  421. [ model.description.output[0].name ]);
  422. return 'Tree Ensemble Regressor';
  423. }
  424. case 'supportVectorClassifier': {
  425. this._createNode(scope, group, 'supportVectorClassifier', null, description,
  426. {
  427. coefficients: model.supportVectorClassifier.coefficients,
  428. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  429. kernel: model.supportVectorClassifier.kernel,
  430. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  431. probA: model.supportVectorClassifier.probA,
  432. probB: model.supportVectorClassifier.probB,
  433. rho: model.supportVectorClassifier.rho,
  434. supportVectors: model.supportVectorClassifier.supportVectors
  435. },
  436. [ model.description.input[0].name ],
  437. [ model.description.output[0].name ]);
  438. this._updateClassifierOutput(group, model.supportVectorClassifier);
  439. return 'Support Vector Classifier';
  440. }
  441. case 'supportVectorRegressor': {
  442. this._createNode(scope, group, 'supportVectorRegressor', null, description,
  443. {
  444. coefficients: model.supportVectorRegressor.coefficients,
  445. kernel: model.supportVectorRegressor.kernel,
  446. rho: model.supportVectorRegressor.rho,
  447. supportVectors: model.supportVectorRegressor.supportVectors
  448. },
  449. [ model.description.input[0].name ],
  450. [ model.description.output[0].name ]);
  451. return 'Support Vector Regressor';
  452. }
  453. case 'arrayFeatureExtractor': {
  454. this._createNode(scope, group, 'arrayFeatureExtractor', null, description,
  455. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  456. [ model.description.input[0].name ],
  457. [ model.description.output[0].name ]);
  458. return 'Array Feature Extractor';
  459. }
  460. case 'oneHotEncoder': {
  461. const categoryType = model.oneHotEncoder.CategoryType;
  462. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  463. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  464. this._createNode(scope, group, 'oneHotEncoder', null, description,
  465. oneHotEncoderParams,
  466. [ model.description.input[0].name ],
  467. [ model.description.output[0].name ]);
  468. return 'One Hot Encoder';
  469. }
  470. case 'imputer': {
  471. const imputedValue = model.imputer.ImputedValue;
  472. const replaceValue = model.imputer.ReplaceValue;
  473. const imputerParams = {};
  474. imputerParams[imputedValue] = model.imputer[imputedValue];
  475. imputerParams[replaceValue] = model.imputer[replaceValue];
  476. this._createNode(scope, group, 'oneHotEncoder', null, description,
  477. imputerParams,
  478. [ model.description.input[0].name ],
  479. [ model.description.output[0].name ]);
  480. return 'Imputer';
  481. }
  482. case 'normalizer': {
  483. this._createNode(scope, group, 'normalizer', null, description,
  484. model.normalizer,
  485. [ model.description.input[0].name ],
  486. [ model.description.output[0].name ]);
  487. return 'Normalizer';
  488. }
  489. case 'wordTagger': {
  490. this._createNode(scope, group, 'wordTagger', null, description,
  491. model.wordTagger,
  492. [ model.description.input[0].name ],
  493. [
  494. model.wordTagger.tokensOutputFeatureName,
  495. model.wordTagger.tokenTagsOutputFeatureName,
  496. model.wordTagger.tokenLocationsOutputFeatureName,
  497. model.wordTagger.tokenLengthsOutputFeatureName
  498. ]);
  499. return 'Word Tagger';
  500. }
  501. case 'textClassifier': {
  502. this._createNode(scope, group, 'textClassifier', null, description,
  503. model.textClassifier,
  504. [ model.description.input[0].name ],
  505. [ model.description.output[0].name ]);
  506. return 'Text Classifier';
  507. }
  508. case 'nonMaximumSuppression': {
  509. const nonMaximumSuppressionParams = {
  510. pickTop: model.nonMaximumSuppression.pickTop,
  511. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  512. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  513. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  514. };
  515. this._createNode(scope, group, 'nonMaximumSuppression', null, description,
  516. nonMaximumSuppressionParams,
  517. [
  518. model.nonMaximumSuppression.confidenceInputFeatureName,
  519. model.nonMaximumSuppression.coordinatesInputFeatureName,
  520. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  521. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  522. ],
  523. [
  524. model.nonMaximumSuppression.confidenceOutputFeatureName,
  525. model.nonMaximumSuppression.coordinatesOutputFeatureName
  526. ]);
  527. return 'Non Maximum Suppression';
  528. }
  529. case 'visionFeaturePrint': {
  530. const visionFeaturePrintParams = {
  531. scene: model.visionFeaturePrint.scene
  532. };
  533. this._createNode(scope, group, 'visionFeaturePrint', null, description,
  534. visionFeaturePrintParams,
  535. [ model.description.input[0].name ],
  536. [ model.description.output[0].name ]);
  537. return 'Vision Feature Print';
  538. }
  539. case 'soundAnalysisPreprocessing': {
  540. this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description,
  541. model.soundAnalysisPreprocessing,
  542. [ model.description.input[0].name ],
  543. [ model.description.output[0].name ]);
  544. return 'Sound Analysis Preprocessing';
  545. }
  546. case 'kNearestNeighborsClassifier': {
  547. this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description,
  548. model.kNearestNeighborsClassifier,
  549. [ model.description.input[0].name ],
  550. [ model.description.output[0].name ]);
  551. this._updateClassifierOutput(group, model.kNearestNeighborsClassifier);
  552. return 'Nearest Neighbors Classifier';
  553. }
  554. case 'itemSimilarityRecommender': {
  555. this._createNode(scope, group, 'itemSimilarityRecommender', null, description,
  556. {
  557. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  558. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  559. },
  560. model.description.input.map((feature) => feature.name),
  561. model.description.output.map((feature) => feature.name));
  562. return 'Item Similarity Recommender';
  563. }
  564. case 'linkedModel': {
  565. this._createNode(scope, group, 'linkedModel', null, description,
  566. model.linkedModel.linkedModelFile,
  567. [ model.description.input[0].name ],
  568. [ model.description.output[0].name ]);
  569. return 'Linked Model';
  570. }
  571. case 'customModel': {
  572. this._createNode(scope, group, 'customModel', null, description,
  573. { className: model.customModel.className, parameters: model.customModel.parameters },
  574. [ model.description.input[0].name ],
  575. [ model.description.output[0].name ]);
  576. return 'customModel';
  577. }
  578. case 'mlProgram': {
  579. return this._loadProgram(model.mlProgram, scope, group, weights);
  580. }
  581. }
  582. throw new coreml.Error("Unknown model type '" + JSON.stringify(Object.keys(model)) + "'.");
  583. }
  584. _loadProgram(program, scope, group, weights) {
  585. // TODO: need to handle functions other than main?
  586. const main = program.functions.main;
  587. // TODO: need to handle more than one block specialization?
  588. const block = main.block_specializations.CoreML5;
  589. const convertValue = (value) => {
  590. switch (value.value) {
  591. case 'immediateValue': {
  592. const tensor = value.immediateValue.tensor;
  593. let values = null;
  594. switch (tensor.value) {
  595. case 'ints':
  596. values = tensor.ints.values;
  597. break;
  598. case 'strings':
  599. values = tensor.strings.values;
  600. break;
  601. case 'bools':
  602. values = tensor.bools.values;
  603. break;
  604. case 'floats':
  605. values = tensor.floats.values;
  606. break;
  607. case 'bytes':
  608. values = tensor.bytes.values;
  609. break;
  610. default:
  611. throw new coreml.Error("Unsupported tensor value '" + tensor.value + "'.");
  612. }
  613. return values;
  614. }
  615. case 'blobFileValue': {
  616. const type = coreml.Utility.valueType(value.type);
  617. const blob = value.blobFileValue;
  618. const offset = blob.offset.toNumber();
  619. const file = blob.fileName;
  620. let data = null;
  621. const stream = weights.get(file);
  622. if (stream) {
  623. stream.seek(offset);
  624. const buffer = stream.read(32);
  625. const reader = new coreml.BinaryReader(buffer);
  626. const signature = reader.uint32();
  627. if (signature == 0xdeadbeef) {
  628. reader.uint32(); // dataType
  629. const size = reader.uint64();
  630. stream.seek(reader.uint64());
  631. const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
  632. switch (type.dataType) {
  633. case 'float32': {
  634. const buffer = stream.read(size);
  635. data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
  636. break;
  637. }
  638. case 'float16': {
  639. data = stream.read(size);
  640. break;
  641. }
  642. default:
  643. throw new coreml.Error("Unsupported blob data type '" + type.dataType + "'.");
  644. }
  645. }
  646. }
  647. return new coreml.Tensor('Blob', type, data);
  648. }
  649. }
  650. throw new coreml.Error("Unsupported value '" + value.value + "'.");
  651. };
  652. const args = new Map();
  653. const arg = (name) => {
  654. if (!args.has(name)) {
  655. args.set(name, { name: name, to: [], from: [] });
  656. }
  657. return args.get(name);
  658. };
  659. const operations = block.operations.map((op) => {
  660. const operation = {
  661. type: op.type,
  662. attributes: {}
  663. };
  664. for (const entry of Object.entries(op.attributes)) {
  665. const key = entry[0];
  666. const value = entry[1];
  667. operation.attributes[key] = convertValue(value);
  668. }
  669. operation.inputs = Object.entries(op.inputs).map((entry) => {
  670. const key = entry[0];
  671. const input = entry[1];
  672. const args = input.arguments.map((argument) => {
  673. if (argument.name) {
  674. const value = arg(argument.name);
  675. value.to.push(operation);
  676. return value;
  677. }
  678. return { value: argument.value };
  679. });
  680. return {
  681. name: key,
  682. arguments: args
  683. };
  684. });
  685. operation.outputs = op.outputs.map((output) => {
  686. const value = arg(output.name);
  687. value.type = coreml.Utility.valueType(output.type);
  688. value.from.push(operation);
  689. return {
  690. name: 'output',
  691. arguments: [ value ]
  692. };
  693. });
  694. return operation;
  695. });
  696. for (const op of operations) {
  697. if (op.type === 'const' && op.inputs.length === 0 &&
  698. op.outputs.length === 1 && op.outputs[0].arguments.length === 1) {
  699. const argument = op.outputs[0].arguments[0];
  700. if (op.attributes && op.attributes.val) {
  701. const type = argument.type;
  702. const data = op.attributes.val;
  703. if (data instanceof Uint8Array && data.length === 2 &&
  704. type.dataType === 'float16' && type.shape.dimensions.length === 0) {
  705. const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  706. argument.value = view.getFloat16(0, true);
  707. }
  708. else {
  709. argument.value = data;
  710. }
  711. argument.const = true;
  712. op.delete = true;
  713. }
  714. }
  715. }
  716. for (const op of operations) {
  717. for (const input of op.inputs) {
  718. if (input.arguments.length > 1 && input.arguments.some((argument) => argument.const)) {
  719. if (input.arguments.every((argument) => argument.value instanceof coreml.Tensor)) {
  720. continue;
  721. }
  722. for (const argument of input.arguments) {
  723. for (const from of argument.from) {
  724. from.delete = false;
  725. }
  726. delete argument.value;
  727. }
  728. }
  729. }
  730. }
  731. for (const op of operations) {
  732. if (op.delete) {
  733. continue;
  734. }
  735. op.inputs = op.inputs.filter((input) => {
  736. if (input.arguments.every((argument) => argument.value === undefined || argument.value instanceof coreml.Tensor)) {
  737. return true;
  738. }
  739. if (input.arguments.length === 1) {
  740. const argument = input.arguments[0];
  741. op.attributes[input.name] = argument.value;
  742. return false;
  743. }
  744. op.attributes[input.name] = input.arguments.map((argument) => argument.value[0]);
  745. return false;
  746. });
  747. }
  748. const tensors = new Map();
  749. const tensor = (arg) => {
  750. if (!tensors.has(arg.name)) {
  751. tensors.set(arg.name, new coreml.Argument(arg.name, arg.type, null, arg.value));
  752. }
  753. return tensors.get(arg.name);
  754. };
  755. for (const op of operations) {
  756. if (op.delete) {
  757. continue;
  758. }
  759. op.inputs = op.inputs.map((input) => new coreml.Parameter(input.name, true, input.arguments.map((argument) => tensor(argument))));
  760. op.outputs = op.outputs.map((output) => new coreml.Parameter(output.name, true, output.arguments.map((argument) => tensor(argument))));
  761. }
  762. for (const op of operations.filter((op) => !op.delete)) {
  763. const type = 'program:' + op.type;
  764. const metadata = this._metadata.type(type);
  765. if (metadata && Array.isArray(metadata.inputs)) {
  766. let index = 1;
  767. const map = new Map(metadata.inputs.map((input) => [ input.name, index++ ]));
  768. op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
  769. }
  770. const node = new coreml.Node(this._metadata, group, type, null, null, op.attributes, op.inputs, op.outputs);
  771. this._nodes.push(node);
  772. }
  773. return 'ML Program';
  774. }
  775. _createNode(scope, group, type, name, description, data, inputs, outputs, outputTypes) {
  776. inputs = inputs.map((input) => scope[input] ? scope[input].argument : input);
  777. outputs = outputs.map((output) => {
  778. if (scope[output]) {
  779. scope[output].counter++;
  780. const next = output + '\n' + scope[output].counter.toString(); // custom argument id
  781. scope[output].argument = next;
  782. return next;
  783. }
  784. scope[output] = {
  785. argument: output,
  786. counter: 0
  787. };
  788. return output;
  789. });
  790. const initializers = [];
  791. const attributes = {};
  792. if (data) {
  793. const map = this._initialize(type, data, initializers);
  794. for (const key of Object.keys(data)) {
  795. if (map[key]) {
  796. continue;
  797. }
  798. attributes[key] = data[key];
  799. }
  800. }
  801. const inputParameters = this._metadata.getInputs(type, inputs).map((input) => {
  802. return new coreml.Parameter(input.name, true, input.arguments.map((argument) => {
  803. return new coreml.Argument(argument.name, argument.type, null, null);
  804. }));
  805. });
  806. inputParameters.push(...initializers);
  807. const outputParameters = outputs.map((output, index) => {
  808. const name = this._metadata.getOutputName(type, index);
  809. const outputType = outputTypes ? outputTypes[index] : null;
  810. return new coreml.Parameter(name, true, [ new coreml.Argument(output, outputType, null, null) ]);
  811. });
  812. const node = new coreml.Node(this._metadata, group, type, name, description, attributes, inputParameters, outputParameters);
  813. this._nodes.push(node);
  814. return node;
  815. }
  816. _initializer(type, initializers, kind, name, shape, data) {
  817. let dataType = '?';
  818. let quantization = null;
  819. let values = null;
  820. if (data) {
  821. if (data.floatValue && data.floatValue.length > 0) {
  822. values = data.floatValue;
  823. dataType = 'float32';
  824. }
  825. else if (data.float16Value && data.float16Value.length > 0) {
  826. values = data.float16Value; // byte[]
  827. dataType = 'float16';
  828. }
  829. else if (data.rawValue && data.rawValue.length > 0) {
  830. if (data.quantization) {
  831. values = data.rawValue;
  832. dataType = 'uint' + data.quantization.numberOfBits.toString();
  833. }
  834. else {
  835. shape = [];
  836. }
  837. }
  838. quantization = data.quantization || null;
  839. }
  840. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  841. const tensor = new coreml.Tensor(kind, tensorType, values, quantization);
  842. const argument = new coreml.Argument('', null, null, tensor);
  843. const visible = this._metadata.visible(type, name);
  844. initializers.push(new coreml.Parameter(name, visible, [ argument ]));
  845. }
  846. _initialize(type, data, initializers) {
  847. switch (type) {
  848. case 'convolution': {
  849. const weightsShape = [ data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1] ];
  850. if (data.isDeconvolution) {
  851. weightsShape[0] = data.kernelChannels;
  852. weightsShape[1] = Math.floor(data.outputChannels / (data.nGroups != 0 ? data.nGroups : 1));
  853. }
  854. this._initializer(type, initializers, 'Weights', 'weights', weightsShape, data.weights);
  855. if (data.hasBias) {
  856. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  857. }
  858. return { 'weights': true, 'bias': data.hasBias };
  859. }
  860. case 'innerProduct':
  861. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputChannels, data.inputChannels ], data.weights);
  862. if (data.hasBias) {
  863. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  864. }
  865. return { 'weights': true, 'bias': data.hasBias };
  866. case 'batchnorm':
  867. this._initializer(type, initializers, 'Weights', 'gamma', [ data.channels ], data.gamma);
  868. this._initializer(type, initializers, 'Weights', 'beta', [ data.channels ], data.beta);
  869. if (data.mean) {
  870. this._initializer(type, initializers, 'Weights', 'mean', [ data.channels ], data.mean);
  871. }
  872. if (data.variance) {
  873. this._initializer(type, initializers, 'Weights', 'variance', [ data.channels ], data.variance);
  874. }
  875. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  876. case 'embedding':
  877. this._initializer(type, initializers, 'Weights', 'weights', [ data.inputDim, data.outputChannels ], data.weights);
  878. return { 'weights': true };
  879. case 'loadConstant':
  880. case 'loadConstantND':
  881. this._initializer(type, initializers, 'Weights', 'data', data.shape, data.data);
  882. return { 'data': true };
  883. case 'scale':
  884. this._initializer(type, initializers, 'Weights', 'scale', data.shapeScale, data.scale);
  885. if (data.hasBias) {
  886. this._initializer(type, initializers, 'Weights', 'bias', data.shapeBias, data.bias);
  887. }
  888. return { 'scale': true, 'bias': data.hasBias };
  889. case 'bias':
  890. this._initializer(type, initializers, 'Weights', 'bias', data.shape, data.bias);
  891. return { 'bias': true };
  892. case 'simpleRecurrent':
  893. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix);
  894. this._initializer(type, initializers, 'Weights', 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix);
  895. if (data.hasBiasVectors) {
  896. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputVectorSize ], data.biasVector);
  897. }
  898. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  899. case 'gru': {
  900. const recursionMatrixShape = [ data.outputVectorSize, data.outputVectorSize ];
  901. const weightMatrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  902. const biasVectorShape = [ data.outputVectorSize ];
  903. this._initializer(type, initializers, 'Weights', 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  904. this._initializer(type, initializers, 'Weights', 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  905. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  906. this._initializer(type, initializers, 'Weights', 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  907. this._initializer(type, initializers, 'Weights', 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  908. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  909. if (data.hasBiasVectors) {
  910. this._initializer(type, initializers, 'Weights', 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  911. this._initializer(type, initializers, 'Weights', 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  912. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  913. }
  914. return {
  915. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  916. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  917. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  918. };
  919. }
  920. case 'uniDirectionalLSTM':
  921. case 'biDirectionalLSTM': {
  922. const count = (type == 'uniDirectionalLSTM') ? 1 : 2;
  923. const matrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  924. const vectorShape = [ data.outputVectorSize ];
  925. for (let i = 0; i < count; i++) {
  926. const weights = count == 1 ? data.weightParams : data.weightParams[i];
  927. const suffix = (i == 0) ? '' : '_rev';
  928. this._initializer(type, initializers, 'Weights', 'inputGateWeightMatrix' + suffix, matrixShape, weights.inputGateWeightMatrix);
  929. this._initializer(type, initializers, 'Weights', 'forgetGateWeightMatrix' + suffix, matrixShape, weights.forgetGateWeightMatrix);
  930. this._initializer(type, initializers, 'Weights', 'blockInputWeightMatrix' + suffix, matrixShape, weights.blockInputWeightMatrix);
  931. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix' + suffix, matrixShape, weights.outputGateWeightMatrix);
  932. this._initializer(type, initializers, 'Weights', 'inputGateRecursionMatrix' + suffix, matrixShape, weights.inputGateRecursionMatrix);
  933. this._initializer(type, initializers, 'Weights', 'forgetGateRecursionMatrix' + suffix, matrixShape,weights.forgetGateRecursionMatrix);
  934. this._initializer(type, initializers, 'Weights', 'blockInputRecursionMatrix' + suffix, matrixShape, weights.blockInputRecursionMatrix);
  935. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix' + suffix, matrixShape, weights.outputGateRecursionMatrix);
  936. if (data.params.hasBiasVectors) {
  937. this._initializer(type, initializers, 'Weights', 'inputGateBiasVector' + suffix, vectorShape, weights.inputGateBiasVector);
  938. this._initializer(type, initializers, 'Weights', 'forgetGateBiasVector' + suffix, vectorShape, weights.forgetGateBiasVector);
  939. this._initializer(type, initializers, 'Weights', 'blockInputBiasVector' + suffix, vectorShape, weights.blockInputBiasVector);
  940. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector' + suffix, vectorShape, weights.outputGateBiasVector);
  941. }
  942. if (data.params.hasPeepholeVectors) {
  943. this._initializer(type, initializers, 'Weights', 'inputGatePeepholeVector' + suffix, vectorShape, weights.inputGatePeepholeVector);
  944. this._initializer(type, initializers, 'Weights', 'forgetGatePeepholeVector' + suffix, vectorShape, weights.forgetGatePeepholeVector);
  945. this._initializer(type, initializers, 'Weights', 'outputGatePeepholeVector' + suffix, vectorShape, weights.outputGatePeepholeVector);
  946. }
  947. }
  948. return { 'weightParams': true };
  949. }
  950. case 'dictVectorizer':
  951. data.stringToIndex = this._convertVector(data.stringToIndex);
  952. return {};
  953. case 'wordTagger':
  954. data.modelParameterData = Array.from(data.modelParameterData);
  955. data.stringTags = this._convertVector(data.stringTags);
  956. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  957. case 'textClassifier':
  958. data.modelParameterData = Array.from(data.modelParameterData);
  959. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  960. return {};
  961. case 'nonMaximumSuppression':
  962. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  963. return {};
  964. }
  965. return {};
  966. }
  967. _convertVector(value) {
  968. if (value && Object.keys(value).length == 1 && value.vector) {
  969. return value.vector;
  970. }
  971. return value;
  972. }
  973. static _formatFeatureDescriptionList(list) {
  974. return list.map((item) => item.name);
  975. }
  976. };
  // A named input or output of a node together with the arguments bound to it.
  coreml.Parameter = class {

      // name: parameter name, visible: visibility flag, args: array of bound arguments.
      constructor(name, visible, args) {
          this._arguments = args;
          this._visible = visible;
          this._name = name;
      }

      get arguments() {
          return this._arguments;
      }

      get visible() {
          return this._visible;
      }

      get name() {
          return this._name;
      }
  };
  // A value flowing through the graph: an identifier plus an optional type, description and
  // initializer. When an initializer is present its type and quantization take precedence.
  coreml.Argument = class {

      // Throws coreml.Error when `name` is not a string.
      constructor(name, type, description, initializer) {
          const isIdentifier = typeof name === 'string';
          if (!isIdentifier) {
              throw new coreml.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type;
          this._description = description || null;
          this._initializer = initializer || null;
      }

      get name() {
          return this._name;
      }

      set name(value) {
          this._name = value;
      }

      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get description() {
          return this._description;
      }

      get quantization() {
          return this._initializer ? this._initializer.quantization : null;
      }

      get initializer() {
          return this._initializer;
      }
  };
  // A graph node: a copy of the type metadata (renamed to the part of `type` after the last ':'),
  // input/output parameters and a list of coreml.Attribute objects built from `attributes`.
  coreml.Node = class {

      // Throws Error when `type` is empty.
      constructor(metadata, group, type, name, description, attributes, inputs, outputs) {
          if (!type) {
              throw new Error('Undefined node type.');
          }
          if (group) {
              this._group = group;
          }
          const schema = metadata.type(type) || { name: type };
          // Copy so that renaming below does not touch the shared metadata entry.
          this._type = Object.assign({}, schema);
          this._type.name = type.split(':').pop();
          this._name = name || '';
          this._description = description || '';
          this._inputs = inputs;
          this._outputs = outputs;
          this._attributes = !attributes ? [] : Object.keys(attributes).map((key) => {
              const attributeSchema = metadata.attribute(type, key);
              return new coreml.Attribute(attributeSchema, key, attributes[key]);
          });
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get description() {
          return this._description;
      }

      // NOTE(review): the constructor never assigns this._metadata, so this yields undefined.
      get metadata() {
          return this._metadata;
      }

      get group() {
          return this._group || null;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  // A named node attribute. The optional schema entry (`metadata`) may supply a type (used to
  // map enum numbers to names), an explicit 'visible' flag, and a 'default' value; an attribute
  // equal to its default is hidden.
  coreml.Attribute = class {

      constructor(metadata, name, value) {
          this._name = name;
          this._value = value;
          if (this._value instanceof coreml.Tensor) {
              this._type = 'tensor';
          }
          if (metadata) {
              if (metadata.type) {
                  this._type = metadata.type;
              }
              if (this._type && coreml.proto) {
                  this._value = coreml.Utility.enum(this._type, this._value);
              }
              if (Object.prototype.hasOwnProperty.call(metadata, 'visible') && !metadata.visible) {
                  this._visible = false;
              }
              else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
                  // Array items may be 64-bit integer wrappers exposing toNumber(); plain numbers,
                  // strings and booleans are compared as they are instead of throwing.
                  const comparable = !Array.isArray(value) ? value : value.map((item) => {
                      const convertible = item !== null && item !== undefined && typeof item.toNumber === 'function';
                      return convertible ? item.toNumber() : item;
                  });
                  if (JSON.stringify(metadata.default) == JSON.stringify(comparable)) {
                      this._visible = false;
                  }
              }
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      // True unless the constructor explicitly hid the attribute.
      get visible() {
          return this._visible == false ? false : true;
      }
  };
  // Tensor data with a kind label, a type (dataType + shape) and optional quantization info.
  // `value` decodes the data into nested arrays following the shape (a bare element for a
  // rank-0 tensor); `toString` does the same with an element limit and JSON formatting.
  coreml.Tensor = class {

      constructor(kind, type, data, quantization) {
          this._kind = kind;
          this._type = type;
          this._data = data;
          this._quantization = quantization;
      }

      get kind() {
          return this._kind;
      }

      get type() {
          return this._type;
      }

      // Returns 'index = value; ...' for a non-empty lookup table, '?' for any other
      // quantization object and null when the tensor is not quantized.
      get quantization() {
          if (this._quantization) {
              if (this._quantization.lookupTableQuantization &&
                  this._quantization.lookupTableQuantization.floatValue &&
                  this._quantization.lookupTableQuantization.floatValue.length > 0) {
                  const map = [];
                  for (const key of Object.keys(this._quantization.lookupTableQuantization.floatValue)) {
                      map.push(key.toString() + ' = ' + this._quantization.lookupTableQuantization.floatValue[key].toString());
                  }
                  return map.join('; ');
              }
              return '?';
          }
          return null;
      }

      // Null when the tensor can be decoded, otherwise a message explaining why not.
      get state() {
          return this._context().state;
      }

      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      // Builds the decoding cursor; sets context.state when there is nothing to decode.
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          context.dataType = this._type.dataType;
          context.dimensions = this._type.shape.dimensions;
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          switch (context.dataType) {
              case 'float32':
                  context.data = this._data;
                  break;
              case 'float16':
                  context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
                  break;
              default:
                  if (this._quantization) {
                      context.dataType = 'quantization';
                      context.bits = this._quantization.numberOfBits.toNumber();
                      context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
                  }
                  else {
                      context.state = 'Tensor data type is not implemented.';
                  }
                  break;
          }
          return context;
      }

      // Recursively decodes dimension `dimension` and the ones below it.
      _decode(context, dimension) {
          // A rank-0 tensor still holds exactly one element: decode it as shape [1] and unwrap.
          const scalar = context.dimensions.length == 0;
          const shape = scalar ? [ 1 ] : context.dimensions;
          const results = [];
          const size = shape[dimension];
          if (dimension == shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'float32':
                          results.push(context.data[context.index]);
                          context.index++;
                          break;
                      case 'float16':
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          break;
                      case 'quantization':
                          results.push(context.data.getBits(context.index, context.bits));
                          context.index++;
                          break;
                  }
                  context.count++;
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (scalar) {
              return results[0];
          }
          return results;
      }
  };
  // Element data type plus shape; a missing shape defaults to an empty coreml.TensorShape.
  coreml.TensorType = class {

      constructor(dataType, shape) {
          this._shape = shape || new coreml.TensorShape([]);
          this._dataType = dataType;
      }

      get shape() {
          return this._shape;
      }

      get dataType() {
          return this._dataType;
      }

      // Data type immediately followed by the shape's own string form.
      toString() {
          const suffix = this._shape.toString();
          return this.dataType + suffix;
      }
  };
  // Ordered list of tensor dimensions.
  coreml.TensorShape = class {

      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      // '' for a missing or empty dimension list, otherwise '[d0,d1,...]'.
      toString() {
          const dimensions = this._dimensions;
          const hasDimensions = dimensions && dimensions.length != 0;
          if (hasDimensions) {
              const parts = dimensions.map((dimension) => dimension.toString());
              return '[' + parts.join(',') + ']';
          }
          return '';
      }
  };
  // List type wrapping an element type; rendered as 'list<element>'.
  coreml.ListType = class {

      constructor(elementType) {
          this._elementType = elementType;
      }

      toString() {
          const element = this._elementType.toString();
          return 'list<' + element + '>';
      }
  };
  // Map type with a key type and a value type; rendered as 'map<key,value>'.
  coreml.MapType = class {

      constructor(keyType, valueType) {
          this._valueType = valueType;
          this._keyType = keyType;
      }

      get valueType() {
          return this._valueType;
      }

      get keyType() {
          return this._keyType;
      }

      toString() {
          const value = this._valueType.toString();
          return 'map<' + this._keyType + ',' + value + '>';
      }
  };
  // Image feature type: color space label ('?' when the enum value is not recognised) and size.
  coreml.ImageType = class {

      constructor(colorSpace, width, height) {
          const ColorSpace = coreml.proto.ImageFeatureType.ColorSpace;
          if (colorSpace === ColorSpace.GRAYSCALE) {
              this._colorSpace = 'Grayscale';
          }
          else if (colorSpace === ColorSpace.RGB) {
              this._colorSpace = 'RGB';
          }
          else if (colorSpace === ColorSpace.BGR) {
              this._colorSpace = 'BGR';
          }
          else {
              this._colorSpace = '?';
          }
          this._width = width;
          this._height = height;
      }

      // 'image<ColorSpace,WIDTHxHEIGHT>'
      toString() {
          const size = this._width.toString() + 'x' + this._height.toString();
          return 'image<' + this._colorSpace + ',' + size + '>';
      }
  };
  // Marks a wrapped type as optional; rendered as the wrapped type followed by '?'.
  coreml.OptionalType = class {

      constructor(type) {
          this._type = type;
      }

      toString() {
          const inner = this._type.toString();
          return inner + '?';
      }
  };
  // Sequential little-endian reader over a byte buffer (a typed array view such as Uint8Array).
  coreml.BinaryReader = class {

      constructor(buffer) {
          this._buffer = buffer;
          // Needed by skip(); without it the end-of-buffer check compares against undefined.
          this._length = buffer.byteLength;
          this._position = 0;
          this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
      }

      // Advances the cursor by `offset` bytes and returns the position before the move.
      // Throws coreml.Error when the move would run past the end of the buffer.
      skip(offset) {
          const position = this._position;
          this._position += offset;
          if (this._position > this._length) {
              throw new coreml.Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
          }
          return position;
      }

      uint32() {
          const position = this.skip(4);
          return this._dataView.getUint32(position, true);
      }

      // Relies on a DataView.prototype.getUint64 extension that is not defined in this block
      // and whose result exposes toNumber().
      uint64() {
          const position = this.skip(8);
          return this._dataView.getUint64(position, true).toNumber();
      }
  };
  // Static helpers translating protobuf enums and type descriptions into display types.
  coreml.Utility = class {

      // Resolves the dotted `name` below coreml.proto and, when it is found, maps the numeric
      // `value` to the matching key. Returns `value` unchanged when nothing matches.
      static enum(name, value) {
          let type = coreml.proto;
          const parts = name.split('.');
          while (type && parts.length > 0) {
              type = type[parts.shift()];
          }
          if (type) {
              // Reverse (value -> key) lookup tables are built once per enum name.
              coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
              if (!coreml.Utility._enumKeyMap.has(name)) {
                  const map = new Map(Object.entries(type).map((pair) => [ pair[1], pair[0] ]));
                  coreml.Utility._enumKeyMap.set(name, map);
              }
              const map = coreml.Utility._enumKeyMap.get(name);
              if (map.has(value)) {
                  return map.get(value);
              }
          }
          return value;
      }

      // Converts a model feature type description into a display type; '?' when `type` is
      // missing or of a kind not handled below. Optional features are wrapped in OptionalType.
      static featureType(type) {
          let result = '?';
          if (type) {
              switch (type.Type) {
                  case 'multiArrayType': {
                      let shape = new coreml.TensorShape([]);
                      if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) {
                          shape = new coreml.TensorShape(type.multiArrayType.shape);
                      }
                      let dataType = '?';
                      switch (type.multiArrayType.dataType) {
                          case coreml.proto.ArrayFeatureType.ArrayDataType.FLOAT32:
                              dataType = 'float32';
                              break;
                          case coreml.proto.ArrayFeatureType.ArrayDataType.INT32:
                              dataType = 'int32';
                              break;
                          case coreml.proto.ArrayFeatureType.ArrayDataType.DOUBLE:
                              dataType = 'float64';
                              break;
                      }
                      result = new coreml.TensorType(dataType, shape);
                      break;
                  }
                  case 'stringType': {
                      result = new coreml.TensorType('string');
                      break;
                  }
                  case 'doubleType': {
                      result = new coreml.TensorType('float64');
                      break;
                  }
                  case 'int64Type': {
                      result = new coreml.TensorType('int64');
                      break;
                  }
                  case 'dictionaryType': {
                      result = new coreml.MapType(type.dictionaryType.KeyType.replace('KeyType', ''), 'float64');
                      break;
                  }
                  case 'imageType': {
                      result = new coreml.ImageType(type.imageType.colorSpace, type.imageType.width, type.imageType.height);
                      break;
                  }
              }
              if (type.isOptional) {
                  result = new coreml.OptionalType(result);
              }
          }
          return result;
      }

      // Converts an ML Program tensor type into a coreml.TensorType. Dimensions without a
      // constant size become '?'. Throws coreml.Error for a data type missing from the enum.
      static tensorType(type) {
          if (!coreml.Utility._dataTypes) {
              coreml.Utility._dataTypes = new Map();
              const DataType = coreml.proto.MILSpec.DataType;
              for (const pair of Object.entries(DataType)) {
                  if (pair[0] === 'UNUSED_TYPE') {
                      continue;
                  }
                  // Compare in lower case so the rename applies whatever the enum key casing is.
                  const lower = pair[0].toLowerCase();
                  const name = lower === 'bool' ? 'boolean' : lower;
                  coreml.Utility._dataTypes.set(pair[1], name);
              }
          }
          const shape = (type.dimensions.map(dim => dim.constant ? dim.constant.size : '?'));
          const dataType = coreml.Utility._dataTypes.get(type.dataType);
          // Map.get yields undefined (never null) for an unknown key.
          if (dataType === undefined || dataType === null) {
              throw new coreml.Error("Unsupported data type '" + type.dataType + "'.");
          }
          return new coreml.TensorType(dataType, new coreml.TensorShape(shape));
      }

      // Converts an ML Program value type (tensor or list) into a display type.
      // Throws coreml.Error for any other kind.
      static valueType(type) {
          switch (type.type) {
              case 'tensorType':
                  return coreml.Utility.tensorType(type.tensorType);
              case 'listType':
                  return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
              default:
                  throw new coreml.Error("Unsupported value type '" + type.type + "'.");
          }
      }
  };
  // Operator schema lookup backed by a JSON array of type descriptions keyed by 'name'.
  coreml.Metadata = class {

      // Resolves to a shared instance loaded from 'coreml-metadata.json' via context.request.
      // Any failure while loading or parsing deliberately falls back to an empty instance.
      static open(context) {
          if (coreml.Metadata._metadata) {
              return Promise.resolve(coreml.Metadata._metadata);
          }
          const remember = (data) => {
              coreml.Metadata._metadata = new coreml.Metadata(data);
              return coreml.Metadata._metadata;
          };
          return context.request('coreml-metadata.json', 'utf-8', null)
              .then((data) => remember(data))
              .catch(() => remember(null));
      }

      // data: JSON text of the schema array, or a falsy value for an empty schema.
      constructor(data) {
          this._map = data ? new Map(JSON.parse(data).map((item) => [ item.name, item ])) : new Map();
          this._attributeCache = new Map();
          this._inputCache = new Map();
      }

      // Schema entry for `name`, or undefined.
      type(name) {
          return this._map.get(name);
      }

      // Attribute schema for `name` on `type`, or null when unknown.
      attribute(type, name) {
          const cache = this._attributeCache;
          const key = type + ':' + name;
          if (!cache.has(key)) {
              // Remember the miss first; the loop below overwrites it when the attribute exists.
              cache.set(key, null);
              const schema = this.type(type);
              const attributes = schema && Array.isArray(schema.attributes) ? schema.attributes : [];
              for (const attribute of attributes) {
                  cache.set(type + ':' + attribute.name, attribute);
              }
          }
          return cache.get(key);
      }

      // False only when the schema marks input `name` of `type` with visible === false.
      visible(type, name) {
          const cache = this._inputCache;
          const key = type + ':' + name;
          if (!cache.has(key)) {
              cache.set(key, null);
              const schema = this.type(type);
              const inputs = schema && Array.isArray(schema.inputs) ? schema.inputs : [];
              for (const input of inputs) {
                  cache.set(type + ':' + input.name, input);
              }
          }
          const input = cache.get(key);
          return input ? input.visible !== false : true;
      }

      // Groups positional input ids into named parameters. A schema input flagged 'variadic'
      // takes all remaining ids; positions without a schema name are labelled '(index)',
      // except that a type without schema inputs names its first input 'input'.
      getInputs(type, inputs) {
          const schema = this._map.get(type);
          const described = schema && schema.inputs;
          const results = [];
          for (let index = 0; index < inputs.length;) {
              let name = null;
              let count = 1;
              if (described) {
                  if (index < schema.inputs.length) {
                      const input = schema.inputs[index];
                      name = input.name;
                      if (input.option == 'variadic') {
                          count = inputs.length - index;
                      }
                  }
              }
              else if (index == 0) {
                  name = 'input';
              }
              const args = [];
              for (const id of inputs.slice(index, index + count)) {
                  args.push({ name: id });
              }
              results.push({
                  arguments: args,
                  name: name ? name : '(' + index.toString() + ')'
              });
              index += count;
          }
          return results;
      }

      // Schema name of output `index`, else 'output' for index 0 and '(index)' otherwise.
      getOutputName(type, index) {
          const schema = this._map.get(type);
          const outputs = schema ? schema.outputs : null;
          if (outputs && index < outputs.length) {
              const output = outputs[index];
              if (output && output.name) {
                  return output.name;
              }
          }
          return index == 0 ? 'output' : '(' + index.toString() + ')';
      }
  };
  // Error raised for problems encountered while loading a Core ML model.
  coreml.Error = class extends Error {

      constructor(message) {
          super(message);
          // Fixed, user-facing label instead of the class name.
          const label = 'Error loading Core ML model.';
          this.name = label;
      }
  };
  // Publish the model factory when running under a CommonJS-style module system.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = coreml.ModelFactory;
      }
  }