coreml.js 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618
  1. /* jshint esversion: 6 */
  2. var coreml = coreml || {};
  3. var json = json || require('./json');
  4. var protobuf = protobuf || require('./protobuf');
  coreml.ModelFactory = class {

      /**
       * Decides whether this factory can open the file described by `context`.
       * Accepted inputs: a serialized Core ML model, a Core ML package's
       * Manifest.json (referencing exactly one '.mlmodel'), Metadata.json,
       * FeatureDescriptions.json, or a '.bin' file whose first 256 bytes contain
       * the 0xDEADBEEF blob sentinel.
       * @param {object} context - host-supplied file context
       * @returns {boolean} true when `open` should be attempted
       */
      match(context) {
          const tags = context.tags('pb');
          // Protobuf tag signature checked for a CoreML.Specification.Model (fields 1 and 2);
          // presumably context.tags maps field number -> wire type.
          if (tags.get(1) === 0 && tags.get(2) === 2) {
              return true;
          }
          const stream = context.stream;
          const identifier = context.identifier.toLowerCase();
          const extension = identifier.split('.').pop();
          switch (identifier) {
              case 'manifest.json': {
                  const obj = context.open('json');
                  if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                      const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
                      // A package manifest must reference exactly one '.mlmodel' file.
                      const models = entries.filter((entry) => coreml.ModelFactory._isModelEntry(entry));
                      if (models.length === 1) {
                          return true;
                      }
                  }
                  break;
              }
              case 'metadata.json': {
                  const obj = context.open('json');
                  if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                      return true;
                  }
                  break;
              }
              case 'featuredescriptions.json': {
                  const obj = context.open('json');
                  if (obj && (obj.Inputs || obj.Outputs)) {
                      return true;
                  }
                  break;
              }
          }
          if (extension === 'bin' && stream.length > 16) {
              const buffer = stream.peek(Math.min(256, stream.length));
              // Scan for the little-endian 0xDEADBEEF sentinel, including a sentinel that
              // ends exactly on the last peeked byte.
              for (let i = 0; i + 4 <= buffer.length; i++) {
                  const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer[i + 3] << 24) >>> 0;
                  if (signature === 0xdeadbeef) {
                      return true;
                  }
              }
          }
          return false;
      }

      /**
       * Opens the model referenced by `context`.
       * Package auxiliary files (Metadata.json, FeatureDescriptions.json, weight
       * '.bin') locate the package Manifest.json relative to themselves and load
       * the '.mlmodel' it references. Weight files referenced by ML Program
       * operations are requested next to the model; if any request or the
       * weight-aware model construction fails, the model is returned without weights.
       * @param {object} context - host-supplied file context
       * @returns {Promise<coreml.Model>}
       */
      open(context) {
          return context.require('./coreml-proto').then(() => {
              return coreml.Metadata.open(context).then((metadata) => {

                  // Adds the file name of every blobFileValue attribute found in an ML Program.
                  const walkProgram = (program, weightPaths) => {
                      for (const functionName of Object.keys(program.functions)) {
                          const func = program.functions[functionName];
                          for (const blockName of Object.keys(func.block_specializations)) {
                              const block = func.block_specializations[blockName];
                              for (const operation of block.operations) {
                                  for (const attributeName of Object.keys(operation.attributes)) {
                                      const value = operation.attributes[attributeName];
                                      if (value.blobFileValue && value.blobFileValue.fileName) {
                                          weightPaths.add(value.blobFileValue.fileName);
                                      }
                                  }
                              }
                          }
                      }
                  };

                  // Visits a model and all models nested in its pipelines.
                  const walkModel = (model, weightPaths) => {
                      if (model.mlProgram) {
                          walkProgram(model.mlProgram, weightPaths);
                      }
                      const pipelines = [
                          model.pipeline,
                          model.pipelineClassifier && model.pipelineClassifier.pipeline,
                          model.pipelineRegressor && model.pipelineRegressor.pipeline
                      ];
                      for (const pipeline of pipelines) {
                          if (pipeline && pipeline.models) {
                              for (const node of pipeline.models) {
                                  walkModel(node, weightPaths);
                              }
                          }
                      }
                  };

                  // Resolves a weight file name such as '@model_path/weights/weight.bin'
                  // against the folder containing the model file at `modelPath`.
                  const resolveWeightPath = (modelPath, fileName) => {
                      const folder = modelPath.split('/').slice(0, -1).join('/');
                      const items = fileName.split('/');
                      if (items[0] === '@model_path') {
                          if (folder.length > 0) {
                              items[0] = folder;
                          }
                          else {
                              // The model is in the current folder: drop the placeholder instead of
                              // producing an absolute '/...' path.
                              items.shift();
                          }
                      }
                      return items.join('/');
                  };

                  const openModel = (stream, context, path, format) => {
                      let model = null;
                      try {
                          coreml.proto = protobuf.get('coreml').CoreML.Specification;
                          const reader = protobuf.BinaryReader.open(stream);
                          model = coreml.proto.Model.decode(reader);
                      }
                      catch (error) {
                          const message = error && error.message ? error.message : error.toString();
                          throw new coreml.Error('File format is not coreml.Model (' + message.replace(/\.$/, '') + ').');
                      }
                      const weightPaths = new Set();
                      walkModel(model, weightPaths);
                      if (weightPaths.size === 0) {
                          return new coreml.Model(metadata, format, model, new Map());
                      }
                      const keys = Array.from(weightPaths);
                      const promises = keys.map((key) => context.request(resolveWeightPath(path, key), null));
                      return Promise.all(promises).then((streams) => {
                          const weights = new Map();
                          keys.forEach((key, index) => weights.set(key, streams[index]));
                          return new coreml.Model(metadata, format, model, weights);
                      }).catch((/* err */) => {
                          // Best effort: show the model without weight data when weights cannot be used.
                          return new coreml.Model(metadata, format, model, new Map());
                      });
                  };

                  const openManifest = (obj, context, path) => {
                      if (!obj || !obj.itemInfoEntries) {
                          throw new coreml.Error('Invalid Core ML package manifest.');
                      }
                      const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
                      const entry = entries.find((entry) => coreml.ModelFactory._isModelEntry(entry));
                      if (!entry) {
                          throw new coreml.Error("Core ML package manifest does not reference a '.mlmodel' file.");
                      }
                      const file = path + 'Data/' + entry.path;
                      return context.request(file, null).then((stream) => {
                          return openModel(stream, context, file, 'Core ML Package');
                      });
                  };

                  const openManifestStream = (context, path) => {
                      return context.request(path + 'Manifest.json', null).then((stream) => {
                          const reader = json.TextReader.open(stream);
                          const obj = reader.read();
                          return openManifest(obj, context, path);
                      });
                  };

                  const tags = context.tags('pb');
                  if (tags.get(1) === 0 && tags.get(2) === 2) {
                      return openModel(context.stream, context, context.identifier);
                  }
                  const identifier = context.identifier.toLowerCase();
                  switch (identifier) {
                      case 'manifest.json': {
                          const obj = context.open('json');
                          return openManifest(obj, context, '');
                      }
                      case 'featuredescriptions.json':
                      case 'metadata.json': {
                          // Located at '<package>/Data/com.apple.CoreML/'.
                          return openManifestStream(context, '../../');
                      }
                      default: {
                          const extension = identifier.split('.').pop();
                          if (extension === 'bin') {
                              // Weight files live one folder deeper ('.../weights/').
                              return openManifestStream(context, '../../../');
                          }
                          break;
                      }
                  }
                  throw new coreml.Error("Unsupported Core ML file '" + context.identifier + "'.");
              });
          });
      }

      /**
       * Returns true when a Manifest.json item entry points at a '.mlmodel' file.
       * @param {object} entry - a value of the manifest's itemInfoEntries
       * @returns {boolean}
       */
      static _isModelEntry(entry) {
          return Boolean(entry) && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel');
      }
  };
  coreml.Model = class {

      /**
       * Top-level model wrapper exposing the format name, descriptive metadata and graphs.
       * @param {object} metadata - operator metadata, forwarded to coreml.Graph
       * @param {string} [format] - container name; defaults to 'Core ML'
       * @param {object} model - decoded CoreML.Specification.Model
       * @param {Map} weights - weight file name -> stream, forwarded to coreml.Graph
       */
      constructor(metadata, format, model, weights) {
          this._format = (format || 'Core ML') + ' v' + model.specificationVersion.toString();
          this._graphs = [ new coreml.Graph(metadata, model, weights) ];
          const properties = model.description && model.description.metadata;
          if (properties) {
              if (properties.versionString) {
                  this._version = properties.versionString;
              }
              if (properties.author) {
                  this._author = properties.author;
              }
              if (properties.shortDescription) {
                  this._description = properties.shortDescription;
              }
              if (properties.license) {
                  this._license = properties.license;
              }
              // TODO: properties.userDefined (free-form model metadata) is not surfaced yet.
          }
      }

      get format() {
          return this._format;
      }

      get version() {
          return this._version || null;
      }

      get description() {
          return this._description || null;
      }

      get author() {
          return this._author || null;
      }

      get license() {
          return this._license || null;
      }

      get graphs() {
          return this._graphs;
      }
  };
  210. coreml.Graph = class {
  211. constructor(metadata, model, weights) {
  212. this._metadata = metadata;
  213. this._description = model.description;
  214. this._groups = false;
  215. this._inputs = [];
  216. this._outputs = [];
  217. this._nodes = [];
  218. if (this._description) {
  219. this._inputs = this._description.input.map((input) => {
  220. const argument = new coreml.Argument(input.name, coreml.Utility.featureType(input.type), input.shortDescription, null);
  221. return new coreml.Parameter(input.name, true, [ argument ]);
  222. });
  223. this._outputs = this._description.output.map((output) => {
  224. const argument = new coreml.Argument(output.name, coreml.Utility.featureType(output.type), output.shortDescription, null);
  225. return new coreml.Parameter(output.name, true, [ argument ]);
  226. });
  227. }
  228. this._type = this._loadModel(model, {}, '', weights);
  229. }
  230. get name() {
  231. return '';
  232. }
  233. get type() {
  234. return this._type;
  235. }
  236. get inputs() {
  237. return this._inputs;
  238. }
  239. get outputs() {
  240. return this._outputs;
  241. }
  242. get nodes() {
  243. return this._nodes;
  244. }
  245. get groups() {
  246. return this._groups;
  247. }
  248. _updateOutput(name, newName) {
  249. for (const node of this._nodes) {
  250. for (const output of node.outputs) {
  251. for (const argument of output.arguments) {
  252. if (argument.name === name) {
  253. argument.name = newName;
  254. }
  255. }
  256. }
  257. }
  258. return newName;
  259. }
  260. _updateClassifierOutput(group, classifier) {
  261. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  262. if (!labelProbabilityLayerName && this._nodes.length > 0) {
  263. const node = this._nodes.slice(-1).pop();
  264. if (node && node.outputs.length == 1 && node.outputs[0].arguments.length == 1) {
  265. labelProbabilityLayerName = node.outputs[0].arguments[0].name;
  266. }
  267. }
  268. let predictedFeatureName = this._description.predictedFeatureName;
  269. let predictedProbabilitiesName = this._description.predictedProbabilitiesName;
  270. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  271. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  272. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  273. const labelProbabilityInput = this._updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
  274. const type = classifier.ClassLabels;
  275. const inputs = [
  276. new coreml.Parameter('input', true, [ new coreml.Argument(labelProbabilityInput) ])
  277. ];
  278. const outputs = [
  279. new coreml.Parameter('probabilities', true, [ new coreml.Argument(predictedProbabilitiesName) ]),
  280. new coreml.Parameter('feature', true, [ new coreml.Argument(predictedFeatureName) ])
  281. ];
  282. const node = new coreml.Node(this._metadata, this._group, type, null, '', classifier[type], inputs, outputs);
  283. this._nodes.push(node);
  284. }
  285. }
  286. _updatePreprocessing(scope, group, preprocessing) {
  287. if (preprocessing && preprocessing.length > 0) {
  288. const preprocessingInput = this._description.input[0].name;
  289. const inputNodes = [];
  290. for (const node of this._nodes) {
  291. if (node.inputs.some((input) => input.arguments.some((arg) => arg.name == preprocessingInput))) {
  292. inputNodes.push(node);
  293. }
  294. }
  295. let preprocessorOutput = preprocessingInput;
  296. let preprocessorIndex = 0;
  297. for (const p of preprocessing) {
  298. const input = p.featureName ? p.featureName : preprocessorOutput;
  299. preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString();
  300. this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]);
  301. preprocessorIndex++;
  302. }
  303. for (const node of inputNodes) {
  304. for (const input of node.inputs) {
  305. for (const arg of input.arguments) {
  306. if (arg.name === preprocessingInput) {
  307. arg.name = preprocessorOutput;
  308. }
  309. }
  310. }
  311. }
  312. }
  313. }
  314. _loadModel(model, scope, group, weights) {
  315. this._groups = this._groups | (group.length > 0 ? true : false);
  316. const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  317. switch (model.Type) {
  318. case 'neuralNetworkClassifier': {
  319. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  320. for (const layer of neuralNetworkClassifier.layers) {
  321. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  322. }
  323. this._updateClassifierOutput(group, neuralNetworkClassifier);
  324. this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing);
  325. return 'Neural Network Classifier';
  326. }
  327. case 'neuralNetwork': {
  328. const neuralNetwork = model.neuralNetwork;
  329. for (const layer of neuralNetwork.layers) {
  330. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  331. }
  332. this._updatePreprocessing(scope, group, neuralNetwork.preprocessing);
  333. return 'Neural Network';
  334. }
  335. case 'neuralNetworkRegressor': {
  336. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  337. for (const layer of neuralNetworkRegressor.layers) {
  338. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  339. }
  340. this._updatePreprocessing(scope, group, neuralNetworkRegressor);
  341. return 'Neural Network Regressor';
  342. }
  343. case 'pipeline': {
  344. for (let i = 0; i < model.pipeline.models.length; i++) {
  345. this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']');
  346. }
  347. return 'Pipeline';
  348. }
  349. case 'pipelineClassifier': {
  350. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  351. this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']');
  352. }
  353. return 'Pipeline Classifier';
  354. }
  355. case 'pipelineRegressor': {
  356. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  357. this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']');
  358. }
  359. return 'Pipeline Regressor';
  360. }
  361. case 'glmClassifier': {
  362. this._createNode(scope, group, 'glmClassifier', null, description,
  363. {
  364. classEncoding: model.glmClassifier.classEncoding,
  365. offset: model.glmClassifier.offset,
  366. weights: model.glmClassifier.weights
  367. },
  368. [ model.description.input[0].name ],
  369. [ model.description.predictedProbabilitiesName ]);
  370. this._updateClassifierOutput(group, model.glmClassifier);
  371. return 'Generalized Linear Classifier';
  372. }
  373. case 'glmRegressor': {
  374. this._createNode(scope, group, 'glmRegressor', null, description,
  375. model.glmRegressor,
  376. [ model.description.input[0].name ],
  377. [ model.description.output[0].name ]);
  378. return 'Generalized Linear Regressor';
  379. }
  380. case 'dictVectorizer': {
  381. this._createNode(scope, group, 'dictVectorizer', null, description,
  382. model.dictVectorizer,
  383. [ model.description.input[0].name ],
  384. [ model.description.output[0].name ]);
  385. return 'Dictionary Vectorizer';
  386. }
  387. case 'featureVectorizer': {
  388. this._createNode(scope, group, 'featureVectorizer', null, description,
  389. model.featureVectorizer,
  390. coreml.Graph._formatFeatureDescriptionList(model.description.input),
  391. [ model.description.output[0].name ]);
  392. return 'Feature Vectorizer';
  393. }
  394. case 'treeEnsembleClassifier': {
  395. this._createNode(scope, group, 'treeEnsembleClassifier', null, description,
  396. model.treeEnsembleClassifier.treeEnsemble,
  397. [ model.description.input[0].name ],
  398. [ model.description.output[0].name ]);
  399. this._updateClassifierOutput(group, model.treeEnsembleClassifier);
  400. return 'Tree Ensemble Classifier';
  401. }
  402. case 'treeEnsembleRegressor': {
  403. this._createNode(scope, group, 'treeEnsembleRegressor', null, description,
  404. model.treeEnsembleRegressor.treeEnsemble,
  405. [ model.description.input[0].name ],
  406. [ model.description.output[0].name ]);
  407. return 'Tree Ensemble Regressor';
  408. }
  409. case 'supportVectorClassifier': {
  410. this._createNode(scope, group, 'supportVectorClassifier', null, description,
  411. {
  412. coefficients: model.supportVectorClassifier.coefficients,
  413. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  414. kernel: model.supportVectorClassifier.kernel,
  415. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  416. probA: model.supportVectorClassifier.probA,
  417. probB: model.supportVectorClassifier.probB,
  418. rho: model.supportVectorClassifier.rho,
  419. supportVectors: model.supportVectorClassifier.supportVectors
  420. },
  421. [ model.description.input[0].name ],
  422. [ model.description.output[0].name ]);
  423. this._updateClassifierOutput(group, model.supportVectorClassifier);
  424. return 'Support Vector Classifier';
  425. }
  426. case 'supportVectorRegressor': {
  427. this._createNode(scope, group, 'supportVectorRegressor', null, description,
  428. {
  429. coefficients: model.supportVectorRegressor.coefficients,
  430. kernel: model.supportVectorRegressor.kernel,
  431. rho: model.supportVectorRegressor.rho,
  432. supportVectors: model.supportVectorRegressor.supportVectors
  433. },
  434. [ model.description.input[0].name ],
  435. [ model.description.output[0].name ]);
  436. return 'Support Vector Regressor';
  437. }
  438. case 'arrayFeatureExtractor': {
  439. this._createNode(scope, group, 'arrayFeatureExtractor', null, description,
  440. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  441. [ model.description.input[0].name ],
  442. [ model.description.output[0].name ]);
  443. return 'Array Feature Extractor';
  444. }
  445. case 'oneHotEncoder': {
  446. const categoryType = model.oneHotEncoder.CategoryType;
  447. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  448. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  449. this._createNode(scope, group, 'oneHotEncoder', null, description,
  450. oneHotEncoderParams,
  451. [ model.description.input[0].name ],
  452. [ model.description.output[0].name ]);
  453. return 'One Hot Encoder';
  454. }
  455. case 'imputer': {
  456. const imputedValue = model.imputer.ImputedValue;
  457. const replaceValue = model.imputer.ReplaceValue;
  458. const imputerParams = {};
  459. imputerParams[imputedValue] = model.imputer[imputedValue];
  460. imputerParams[replaceValue] = model.imputer[replaceValue];
  461. this._createNode(scope, group, 'oneHotEncoder', null, description,
  462. imputerParams,
  463. [ model.description.input[0].name ],
  464. [ model.description.output[0].name ]);
  465. return 'Imputer';
  466. }
  467. case 'normalizer': {
  468. this._createNode(scope, group, 'normalizer', null, description,
  469. model.normalizer,
  470. [ model.description.input[0].name ],
  471. [ model.description.output[0].name ]);
  472. return 'Normalizer';
  473. }
  474. case 'wordTagger': {
  475. this._createNode(scope, group, 'wordTagger', null, description,
  476. model.wordTagger,
  477. [ model.description.input[0].name ],
  478. [
  479. model.wordTagger.tokensOutputFeatureName,
  480. model.wordTagger.tokenTagsOutputFeatureName,
  481. model.wordTagger.tokenLocationsOutputFeatureName,
  482. model.wordTagger.tokenLengthsOutputFeatureName
  483. ]);
  484. return 'Word Tagger';
  485. }
  486. case 'textClassifier': {
  487. this._createNode(scope, group, 'textClassifier', null, description,
  488. model.textClassifier,
  489. [ model.description.input[0].name ],
  490. [ model.description.output[0].name ]);
  491. return 'Text Classifier';
  492. }
  493. case 'nonMaximumSuppression': {
  494. const nonMaximumSuppressionParams = {
  495. pickTop: model.nonMaximumSuppression.pickTop,
  496. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  497. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  498. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  499. };
  500. this._createNode(scope, group, 'nonMaximumSuppression', null, description,
  501. nonMaximumSuppressionParams,
  502. [
  503. model.nonMaximumSuppression.confidenceInputFeatureName,
  504. model.nonMaximumSuppression.coordinatesInputFeatureName,
  505. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  506. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  507. ],
  508. [
  509. model.nonMaximumSuppression.confidenceOutputFeatureName,
  510. model.nonMaximumSuppression.coordinatesOutputFeatureName
  511. ]);
  512. return 'Non Maximum Suppression';
  513. }
  514. case 'visionFeaturePrint': {
  515. const visionFeaturePrintParams = {
  516. scene: model.visionFeaturePrint.scene
  517. };
  518. this._createNode(scope, group, 'visionFeaturePrint', null, description,
  519. visionFeaturePrintParams,
  520. [ model.description.input[0].name ],
  521. [ model.description.output[0].name ]);
  522. return 'Vision Feature Print';
  523. }
  524. case 'soundAnalysisPreprocessing': {
  525. this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description,
  526. model.soundAnalysisPreprocessing,
  527. [ model.description.input[0].name ],
  528. [ model.description.output[0].name ]);
  529. return 'Sound Analysis Preprocessing';
  530. }
  531. case 'kNearestNeighborsClassifier': {
  532. this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description,
  533. model.kNearestNeighborsClassifier,
  534. [ model.description.input[0].name ],
  535. [ model.description.output[0].name ]);
  536. this._updateClassifierOutput(group, model.kNearestNeighborsClassifier);
  537. return 'Nearest Neighbors Classifier';
  538. }
  539. case 'itemSimilarityRecommender': {
  540. this._createNode(scope, group, 'itemSimilarityRecommender', null, description,
  541. {
  542. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  543. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  544. },
  545. model.description.input.map((feature) => feature.name),
  546. model.description.output.map((feature) => feature.name));
  547. return 'Item Similarity Recommender';
  548. }
  549. case 'linkedModel': {
  550. this._createNode(scope, group, 'linkedModel', null, description,
  551. model.linkedModel.linkedModelFile,
  552. [ model.description.input[0].name ],
  553. [ model.description.output[0].name ]);
  554. return 'Linked Model';
  555. }
  556. case 'customModel': {
  557. this._createNode(scope, group, 'customModel', null, description,
  558. { className: model.customModel.className, parameters: model.customModel.parameters },
  559. [ model.description.input[0].name ],
  560. [ model.description.output[0].name ]);
  561. return 'customModel';
  562. }
  563. case 'mlProgram': {
  564. return this._loadProgram(model.mlProgram, scope, group, weights);
  565. }
  566. }
  567. throw new coreml.Error("Unknown model type '" + JSON.stringify(Object.keys(model)) + "'.");
  568. }
  569. _loadProgram(program, scope, group, weights) {
  570. // TODO: need to handle functions other than main?
  571. const main = program.functions.main;
  572. // TODO: need to handle more than one block specialization?
  573. const block = main.block_specializations.CoreML5;
  574. const convertValue = (value) => {
  575. switch (value.value) {
  576. case 'immediateValue': {
  577. const tensor = value.immediateValue.tensor;
  578. let values = null;
  579. switch (tensor.value) {
  580. case 'ints':
  581. values = tensor.ints.values;
  582. break;
  583. case 'strings':
  584. values = tensor.strings.values;
  585. break;
  586. case 'bools':
  587. values = tensor.bools.values;
  588. break;
  589. case 'floats':
  590. values = tensor.floats.values;
  591. break;
  592. default:
  593. throw new coreml.Error("Unsupported tensor value '" + tensor.value + "'.");
  594. }
  595. return values;
  596. }
  597. case 'blobFileValue': {
  598. const type = coreml.Utility.type(value.type);
  599. const blob = value.blobFileValue;
  600. const offset = blob.offset.toNumber();
  601. const file = blob.fileName;
  602. let data = null;
  603. const stream = weights.get(file);
  604. if (stream) {
  605. stream.seek(offset);
  606. const buffer = stream.read(32);
  607. const reader = new coreml.BinaryReader(buffer);
  608. const signature = reader.uint32();
  609. if (signature == 0xdeadbeef) {
  610. reader.uint32(); // dataType
  611. const size = reader.uint64();
  612. stream.seek(reader.uint64());
  613. const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
  614. switch (type.dataType) {
  615. case 'float32': {
  616. const buffer = stream.read(size);
  617. data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
  618. break;
  619. }
  620. default:
  621. throw new coreml.Error("Unsupported blob data type '" + type.dataType + "'.");
  622. }
  623. }
  624. }
  625. return new coreml.Tensor('Blob', type, data);
  626. }
  627. }
  628. throw new coreml.Error("Unsupported value '" + value.value + "'.");
  629. };
  630. const args = new Map();
  631. const arg = (name) => {
  632. if (!args.has(name)) {
  633. args.set(name, { name: name, to: [], from: [] });
  634. }
  635. return args.get(name);
  636. };
  637. const operations = block.operations.map((op) => {
  638. const operation = {
  639. type: op.type,
  640. attributes: {}
  641. };
  642. for (const key of Object.keys(op.attributes)) {
  643. operation.attributes[key] = convertValue(op.attributes[key]);
  644. }
  645. operation.inputs = Object.keys(op.inputs).map((key) => {
  646. const input = op.inputs[key];
  647. const args = input.arguments.map((argument) => {
  648. if (argument.name) {
  649. const value = arg(argument.name);
  650. value.to.push(operation);
  651. return value;
  652. }
  653. return { value: argument.value };
  654. });
  655. return {
  656. name: key,
  657. arguments: args
  658. };
  659. });
  660. operation.outputs = op.outputs.map((output) => {
  661. const value = arg(output.name);
  662. value.type = coreml.Utility.type(output.type);
  663. value.from.push(operation);
  664. return {
  665. name: 'output',
  666. arguments: [ value ]
  667. };
  668. });
  669. return operation;
  670. });
  671. for (const op of operations) {
  672. if (op.type === 'const' && op.inputs.length === 0 &&
  673. op.outputs.length === 1 && op.outputs[0].arguments.length === 1) {
  674. const argument = op.outputs[0].arguments[0];
  675. if (op.attributes && op.attributes.val) {
  676. argument.value = op.attributes.val;
  677. op.delete = true;
  678. }
  679. }
  680. }
  681. for (const op of operations) {
  682. if (op.delete) {
  683. continue;
  684. }
  685. op.inputs = op.inputs.filter((input) => {
  686. if (input.arguments.length !== 1) {
  687. return true;
  688. }
  689. const argument = input.arguments[0];
  690. if (argument.value === undefined || argument.value instanceof coreml.Tensor) {
  691. return true;
  692. }
  693. op.attributes[input.name] = argument.value;
  694. return false;
  695. });
  696. }
  697. const tensors = new Map();
  698. const tensor = (arg) => {
  699. if (!tensors.has(arg.name)) {
  700. tensors.set(arg.name, new coreml.Argument(arg.name, arg.type, null, arg.value));
  701. }
  702. return tensors.get(arg.name);
  703. };
  704. for (const op of operations) {
  705. if (op.delete) {
  706. continue;
  707. }
  708. op.inputs = op.inputs.map((input) => new coreml.Parameter(input.name, true, input.arguments.map((argument) => tensor(argument))));
  709. op.outputs = op.outputs.map((output) => new coreml.Parameter(output.name, true, output.arguments.map((argument) => tensor(argument))));
  710. }
  711. for (const op of operations.filter((op) => !op.delete)) {
  712. const type = 'program:' + op.type;
  713. const metadata = this._metadata.type(type);
  714. if (metadata && Array.isArray(metadata.inputs)) {
  715. let index = 1;
  716. const map = new Map(metadata.inputs.map((input) => [ input.name, index++ ]));
  717. op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
  718. }
  719. const node = new coreml.Node(this._metadata, group, type, null, null, op.attributes, op.inputs, op.outputs);
  720. this._nodes.push(node);
  721. }
  722. return 'ML Program';
  723. }
  724. _createNode(scope, group, type, name, description, data, inputs, outputs, outputTypes) {
  725. inputs = inputs.map((input) => scope[input] ? scope[input].argument : input);
  726. outputs = outputs.map((output) => {
  727. if (scope[output]) {
  728. scope[output].counter++;
  729. const next = output + '\n' + scope[output].counter.toString(); // custom argument id
  730. scope[output].argument = next;
  731. return next;
  732. }
  733. scope[output] = {
  734. argument: output,
  735. counter: 0
  736. };
  737. return output;
  738. });
  739. const initializers = [];
  740. const attributes = {};
  741. if (data) {
  742. const map = this._initialize(type, data, initializers);
  743. for (const key of Object.keys(data)) {
  744. if (map[key]) {
  745. continue;
  746. }
  747. attributes[key] = data[key];
  748. }
  749. }
  750. const inputParameters = this._metadata.getInputs(type, inputs).map((input) => {
  751. return new coreml.Parameter(input.name, true, input.arguments.map((argument) => {
  752. return new coreml.Argument(argument.name, argument.type, null, null);
  753. }));
  754. });
  755. inputParameters.push(...initializers);
  756. const outputParameters = outputs.map((output, index) => {
  757. const name = this._metadata.getOutputName(type, index);
  758. const outputType = outputTypes ? outputTypes[index] : null;
  759. return new coreml.Parameter(name, true, [ new coreml.Argument(output, outputType, null, null) ]);
  760. });
  761. const node = new coreml.Node(this._metadata, group, type, name, description, attributes, inputParameters, outputParameters);
  762. this._nodes.push(node);
  763. return node;
  764. }
  765. _initializer(type, initializers, kind, name, shape, data) {
  766. let dataType = '?';
  767. let quantization = null;
  768. let values = null;
  769. if (data) {
  770. if (data.floatValue && data.floatValue.length > 0) {
  771. values = data.floatValue;
  772. dataType = 'float32';
  773. }
  774. else if (data.float16Value && data.float16Value.length > 0) {
  775. values = data.float16Value; // byte[]
  776. dataType = 'float16';
  777. }
  778. else if (data.rawValue && data.rawValue.length > 0) {
  779. if (data.quantization) {
  780. values = data.rawValue;
  781. dataType = 'uint' + data.quantization.numberOfBits.toString();
  782. }
  783. else {
  784. shape = [];
  785. }
  786. }
  787. quantization = data.quantization || null;
  788. }
  789. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  790. const tensor = new coreml.Tensor(kind, tensorType, values, quantization);
  791. const argument = new coreml.Argument('', null, null, tensor);
  792. const visible = this._metadata.visible(type, name);
  793. initializers.push(new coreml.Parameter(name, visible, [ argument ]));
  794. }
  795. _initialize(type, data, initializers) {
  796. switch (type) {
  797. case 'convolution': {
  798. const weightsShape = [ data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1] ];
  799. if (data.isDeconvolution) {
  800. weightsShape[0] = data.kernelChannels;
  801. weightsShape[1] = Math.floor(data.outputChannels / (data.nGroups != 0 ? data.nGroups : 1));
  802. }
  803. this._initializer(type, initializers, 'Weights', 'weights', weightsShape, data.weights);
  804. if (data.hasBias) {
  805. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  806. }
  807. return { 'weights': true, 'bias': data.hasBias };
  808. }
  809. case 'innerProduct':
  810. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputChannels, data.inputChannels ], data.weights);
  811. if (data.hasBias) {
  812. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  813. }
  814. return { 'weights': true, 'bias': data.hasBias };
  815. case 'batchnorm':
  816. this._initializer(type, initializers, 'Weights', 'gamma', [ data.channels ], data.gamma);
  817. this._initializer(type, initializers, 'Weights', 'beta', [ data.channels ], data.beta);
  818. if (data.mean) {
  819. this._initializer(type, initializers, 'Weights', 'mean', [ data.channels ], data.mean);
  820. }
  821. if (data.variance) {
  822. this._initializer(type, initializers, 'Weights', 'variance', [ data.channels ], data.variance);
  823. }
  824. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  825. case 'embedding':
  826. this._initializer(type, initializers, 'Weights', 'weights', [ data.inputDim, data.outputChannels ], data.weights);
  827. return { 'weights': true };
  828. case 'loadConstant':
  829. case 'loadConstantND':
  830. this._initializer(type, initializers, 'Weights', 'data', data.shape, data.data);
  831. return { 'data': true };
  832. case 'scale':
  833. this._initializer(type, initializers, 'Weights', 'scale', data.shapeScale, data.scale);
  834. if (data.hasBias) {
  835. this._initializer(type, initializers, 'Weights', 'bias', data.shapeBias, data.bias);
  836. }
  837. return { 'scale': true, 'bias': data.hasBias };
  838. case 'bias':
  839. this._initializer(type, initializers, 'Weights', 'bias', data.shape, data.bias);
  840. return { 'bias': true };
  841. case 'simpleRecurrent':
  842. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix);
  843. this._initializer(type, initializers, 'Weights', 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix);
  844. if (data.hasBiasVectors) {
  845. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputVectorSize ], data.biasVector);
  846. }
  847. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  848. case 'gru': {
  849. const recursionMatrixShape = [ data.outputVectorSize, data.outputVectorSize ];
  850. const weightMatrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  851. const biasVectorShape = [ data.outputVectorSize ];
  852. this._initializer(type, initializers, 'Weights', 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  853. this._initializer(type, initializers, 'Weights', 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  854. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  855. this._initializer(type, initializers, 'Weights', 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  856. this._initializer(type, initializers, 'Weights', 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  857. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  858. if (data.hasBiasVectors) {
  859. this._initializer(type, initializers, 'Weights', 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  860. this._initializer(type, initializers, 'Weights', 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  861. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  862. }
  863. return {
  864. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  865. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  866. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  867. };
  868. }
  869. case 'uniDirectionalLSTM':
  870. case 'biDirectionalLSTM': {
  871. const count = (type == 'uniDirectionalLSTM') ? 1 : 2;
  872. const matrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  873. const vectorShape = [ data.outputVectorSize ];
  874. for (let i = 0; i < count; i++) {
  875. const weights = count == 1 ? data.weightParams : data.weightParams[i];
  876. const suffix = (i == 0) ? '' : '_rev';
  877. this._initializer(type, initializers, 'Weights', 'inputGateWeightMatrix' + suffix, matrixShape, weights.inputGateWeightMatrix);
  878. this._initializer(type, initializers, 'Weights', 'forgetGateWeightMatrix' + suffix, matrixShape, weights.forgetGateWeightMatrix);
  879. this._initializer(type, initializers, 'Weights', 'blockInputWeightMatrix' + suffix, matrixShape, weights.blockInputWeightMatrix);
  880. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix' + suffix, matrixShape, weights.outputGateWeightMatrix);
  881. this._initializer(type, initializers, 'Weights', 'inputGateRecursionMatrix' + suffix, matrixShape, weights.inputGateRecursionMatrix);
  882. this._initializer(type, initializers, 'Weights', 'forgetGateRecursionMatrix' + suffix, matrixShape,weights.forgetGateRecursionMatrix);
  883. this._initializer(type, initializers, 'Weights', 'blockInputRecursionMatrix' + suffix, matrixShape, weights.blockInputRecursionMatrix);
  884. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix' + suffix, matrixShape, weights.outputGateRecursionMatrix);
  885. if (data.params.hasBiasVectors) {
  886. this._initializer(type, initializers, 'Weights', 'inputGateBiasVector' + suffix, vectorShape, weights.inputGateBiasVector);
  887. this._initializer(type, initializers, 'Weights', 'forgetGateBiasVector' + suffix, vectorShape, weights.forgetGateBiasVector);
  888. this._initializer(type, initializers, 'Weights', 'blockInputBiasVector' + suffix, vectorShape, weights.blockInputBiasVector);
  889. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector' + suffix, vectorShape, weights.outputGateBiasVector);
  890. }
  891. if (data.params.hasPeepholeVectors) {
  892. this._initializer(type, initializers, 'Weights', 'inputGatePeepholeVector' + suffix, vectorShape, weights.inputGatePeepholeVector);
  893. this._initializer(type, initializers, 'Weights', 'forgetGatePeepholeVector' + suffix, vectorShape, weights.forgetGatePeepholeVector);
  894. this._initializer(type, initializers, 'Weights', 'outputGatePeepholeVector' + suffix, vectorShape, weights.outputGatePeepholeVector);
  895. }
  896. }
  897. return { 'weightParams': true };
  898. }
  899. case 'dictVectorizer':
  900. data.stringToIndex = this._convertVector(data.stringToIndex);
  901. return {};
  902. case 'wordTagger':
  903. data.modelParameterData = Array.from(data.modelParameterData);
  904. data.stringTags = this._convertVector(data.stringTags);
  905. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  906. case 'textClassifier':
  907. data.modelParameterData = Array.from(data.modelParameterData);
  908. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  909. return {};
  910. case 'nonMaximumSuppression':
  911. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  912. return {};
  913. }
  914. return {};
  915. }
  916. _convertVector(value) {
  917. if (value && Object.keys(value).length == 1 && value.vector) {
  918. return value.vector;
  919. }
  920. return value;
  921. }
  922. static _formatFeatureDescriptionList(list) {
  923. return list.map((item) => item.name);
  924. }
  925. };
  coreml.Parameter = class {
      // A named, optionally hidden group of arguments used for node inputs, outputs and initializers.
      constructor(name, visible, args) {
          Object.assign(this, { _name: name, _visible: visible, _arguments: args });
      }
      get name() {
          return this._name;
      }
      get visible() {
          return this._visible;
      }
      get arguments() {
          return this._arguments;
      }
  };
  coreml.Argument = class {
      // A graph value identified by a string name. When an initializer tensor is attached,
      // its type and quantization take precedence over the declared `type`.
      constructor(name, type, description, initializer) {
          if (typeof name !== 'string') {
              const identifier = JSON.stringify(name);
              throw new coreml.Error("Invalid argument identifier '" + identifier + "'.");
          }
          Object.assign(this, {
              _name: name,
              _type: type,
              _description: description || null,
              _initializer: initializer || null
          });
      }
      get name() {
          return this._name;
      }
      set name(value) {
          this._name = value;
      }
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }
      get description() {
          return this._description;
      }
      get quantization() {
          return this._initializer ? this._initializer.quantization : null;
      }
      get initializer() {
          return this._initializer;
      }
  };
  coreml.Node = class {
      // A graph node.
      //   metadata - provider exposing `type(name)` and `attribute(type, key)`.
      //   type     - node type, optionally namespaced (e.g. 'program:conv'); only the last
      //              ':'-separated segment is used as the displayed name.
      //   attributes - plain object of raw attribute values, wrapped into coreml.Attribute.
      // Throws coreml.Error when `type` is empty.
      constructor(metadata, group, type, name, description, attributes, inputs, outputs) {
          if (!type) {
              throw new coreml.Error('Undefined node type.');
          }
          // Stored so that the `metadata` getter can return the provider.
          this._metadata = metadata;
          if (group) {
              this._group = group;
          }
          // Copy the schema entry so renaming it below does not mutate the shared metadata.
          this._type = Object.assign({}, metadata.type(type) || { name: type });
          this._type.name = type.split(':').pop();
          this._name = name || '';
          this._description = description || '';
          this._inputs = inputs;
          this._outputs = outputs;
          this._attributes = [];
          if (attributes) {
              for (const key of Object.keys(attributes)) {
                  const schema = metadata.attribute(type, key);
                  this._attributes.push(new coreml.Attribute(schema, key, attributes[key]));
              }
          }
      }
      get type() {
          return this._type;
      }
      get name() {
          return this._name;
      }
      get description() {
          return this._description;
      }
      get metadata() {
          return this._metadata;
      }
      get group() {
          return this._group ? this._group : null;
      }
      get inputs() {
          return this._inputs;
      }
      get outputs() {
          return this._outputs;
      }
      get attributes() {
          return this._attributes;
      }
  };
  coreml.Attribute = class {
      // A node attribute. `schema` (optional) may provide `type` (an enum path resolved through
      // coreml.Utility.enum when coreml.proto is loaded), `visible`, and `default`. An attribute
      // whose value equals the schema default is hidden.
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
          if (schema) {
              if (schema.type) {
                  this._type = schema.type;
              }
              if (this._type && coreml.proto) {
                  this._value = coreml.Utility.enum(this._type, this._value);
              }
              if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
                  this._visible = false;
              }
              else if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
                  // 64-bit integers may arrive as objects exposing toNumber(); plain numbers,
                  // strings and null must pass through untouched instead of throwing.
                  const normalize = (item) => item && typeof item.toNumber === 'function' ? item.toNumber() : item;
                  const comparable = Array.isArray(value) ? value.map(normalize) : normalize(value);
                  if (JSON.stringify(schema.default) === JSON.stringify(comparable)) {
                      this._visible = false;
                  }
              }
          }
      }
      get name() {
          return this._name;
      }
      get type() {
          return this._type;
      }
      get value() {
          return this._value;
      }
      get visible() {
          return this._visible == false ? false : true;
      }
  };
  coreml.Tensor = class {
      // Tensor payload. `type` must expose `dataType` and `shape.dimensions`; `data` holds the
      // element values (float32) or the raw bytes (float16 / quantized); `quantization` is the
      // optional quantization message.
      constructor(kind, type, data, quantization) {
          this._kind = kind;
          this._type = type;
          this._data = data;
          this._quantization = quantization;
      }
      get kind() {
          return this._kind;
      }
      get type() {
          return this._type;
      }
      // 'index = value; ...' for a lookup-table quantization, '?' for any other quantization,
      // null when the tensor is not quantized.
      get quantization() {
          const quantization = this._quantization;
          if (!quantization) {
              return null;
          }
          const table = quantization.lookupTableQuantization;
          if (table && table.floatValue && table.floatValue.length > 0) {
              return Object.keys(table.floatValue).map((key) => key.toString() + ' = ' + table.floatValue[key].toString()).join('; ');
          }
          return '?';
      }
      // null when the data can be decoded, otherwise a message explaining why it cannot.
      get state() {
          return this._context().state;
      }
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          context.dataType = this._type.dataType;
          // A missing dimension list is treated like a rank-0 shape instead of crashing _decode.
          context.dimensions = this._type.shape.dimensions || [];
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          switch (context.dataType) {
              case 'float32':
                  context.data = this._data;
                  break;
              case 'float16':
                  context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
                  break;
              default:
                  if (this._quantization) {
                      context.dataType = 'quantization';
                      context.bits = this._quantization.numberOfBits.toNumber();
                      context.data = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
                  }
                  else {
                      context.state = 'Tensor data type is not implemented.';
                  }
                  break;
          }
          return context;
      }
      // Recursively decodes dimension `dimension` onward into nested arrays, appending '...'
      // once more than `context.limit` elements have been produced.
      _decode(context, dimension) {
          // A rank-0 tensor still stores one element; decode it as if its shape were [1].
          const shape = context.dimensions.length === 0 ? [ 1 ] : context.dimensions;
          const size = shape[dimension];
          const results = [];
          if (dimension === shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'float32':
                          results.push(context.data[context.index]);
                          context.index++;
                          break;
                      case 'float16':
                          // NOTE(review): getFloat16/getBits are not available on DataView in every
                          // runtime; presumably polyfilled elsewhere in the project — confirm.
                          results.push(context.data.getFloat16(context.index, true));
                          context.index += 2;
                          break;
                      case 'quantization':
                          results.push(context.data.getBits(context.index, context.bits));
                          context.index++;
                          break;
                  }
                  context.count++;
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return results;
      }
  };
  coreml.TensorType = class {
      // Element data type name plus a coreml.TensorShape; a missing shape defaults to an empty one.
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape ? shape : new coreml.TensorShape([]);
      }
      get dataType() {
          return this._dataType;
      }
      get shape() {
          return this._shape;
      }
      // Data type name immediately followed by the shape's own string form.
      toString() {
          const dimensions = this._shape.toString();
          return this.dataType + dimensions;
      }
  };
  coreml.TensorShape = class {
      // Ordered list of dimensions; each entry must support toString().
      constructor(dimensions) {
          this._dimensions = dimensions;
      }
      get dimensions() {
          return this._dimensions;
      }
      // '[d0,d1,...]', or '' for a missing or empty dimension list.
      toString() {
          const dimensions = this._dimensions;
          if (dimensions && dimensions.length !== 0) {
              const parts = dimensions.map((dimension) => dimension.toString());
              return `[${parts.join(',')}]`;
          }
          return '';
      }
  };
  coreml.MapType = class {
      // Dictionary feature type: a key type name and a value type.
      constructor(keyType, valueType) {
          Object.assign(this, { _keyType: keyType, _valueType: valueType });
      }
      get keyType() {
          return this._keyType;
      }
      get valueType() {
          return this._valueType;
      }
      // Renders as 'map<key,value>'.
      toString() {
          return `map<${this._keyType},${this._valueType.toString()}>`;
      }
  };
  coreml.ImageType = class {
      // Image feature type; `colorSpace` is compared against coreml.proto's
      // ImageFeatureType.ColorSpace values, unknown values are shown as '?'.
      constructor(colorSpace, width, height) {
          const spaces = coreml.proto.ImageFeatureType.ColorSpace;
          const labels = [
              [ spaces.GRAYSCALE, 'Grayscale' ],
              [ spaces.RGB, 'RGB' ],
              [ spaces.BGR, 'BGR' ]
          ];
          const match = labels.find((entry) => entry[0] === colorSpace);
          this._colorSpace = match ? match[1] : '?';
          this._width = width;
          this._height = height;
      }
      // Renders as 'image<ColorSpace,WIDTHxHEIGHT>'.
      toString() {
          return `image<${this._colorSpace},${this._width.toString()}x${this._height.toString()}>`;
      }
  };
  coreml.OptionalType = class {
      // Wraps another feature type to mark it as optional.
      constructor(type) {
          this._type = type;
      }
      // The wrapped type's string form followed by '?'.
      toString() {
          const inner = this._type.toString();
          return `${inner}?`;
      }
  };
  coreml.BinaryReader = class {
      // Little-endian cursor over a byte view; `buffer` must expose buffer/byteOffset/byteLength
      // (e.g. a Uint8Array, including subarrays with a non-zero offset).
      constructor(buffer) {
          this._buffer = buffer;
          this._position = 0;
          this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
          // Upper bound enforced by skip(); without it every bounds check compared against undefined.
          this._length = this._dataView.byteLength;
      }
      // Advances the cursor by `offset` bytes and returns the position before the move.
      // Throws if the move would run past the end of the data.
      skip(offset) {
          const position = this._position;
          this._position += offset;
          if (this._position > this._length) {
              throw new Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
          }
          return position;
      }
      uint32() {
          const position = this.skip(4);
          return this._dataView.getUint32(position, true);
      }
      // NOTE(review): getUint64 is not a standard DataView method; it is presumably polyfilled
      // elsewhere in the project and returns an object with toNumber() — confirm.
      uint64() {
          const position = this.skip(8);
          return this._dataView.getUint64(position, true).toNumber();
      }
  };
  coreml.Utility = class {

      // Maps a numeric enum `value` to its key in the coreml.proto enum addressed by the dotted
      // path `name`. Returns `value` unchanged when the enum or the value is unknown. The reverse
      // maps are cached per `name` on the class.
      static enum(name, value) {
          let type = coreml.proto;
          const parts = name.split('.');
          while (type && parts.length > 0) {
              type = type[parts.shift()];
          }
          if (type) {
              coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
              if (!coreml.Utility._enumKeyMap.has(name)) {
                  const map = new Map(Object.entries(type).map((pair) => [ pair[1], pair[0] ]));
                  coreml.Utility._enumKeyMap.set(name, map);
              }
              const map = coreml.Utility._enumKeyMap.get(name);
              if (map.has(value)) {
                  return map.get(value);
              }
          }
          return value;
      }

      // Converts a feature type description into a TensorType, MapType, ImageType or '?',
      // wrapped in OptionalType when `type.isOptional` is set.
      static featureType(type) {
          let result = '?';
          if (type) {
              switch (type.Type) {
                  case 'multiArrayType': {
                      let shape = new coreml.TensorShape([]);
                      if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) {
                          shape = new coreml.TensorShape(type.multiArrayType.shape);
                      }
                      let dataType = '?';
                      switch (type.multiArrayType.dataType) {
                          case coreml.proto.ArrayFeatureType.ArrayDataType.FLOAT32:
                              dataType = 'float32';
                              break;
                          case coreml.proto.ArrayFeatureType.ArrayDataType.INT32:
                              dataType = 'int32';
                              break;
                          case coreml.proto.ArrayFeatureType.ArrayDataType.DOUBLE:
                              dataType = 'float64';
                              break;
                      }
                      result = new coreml.TensorType(dataType, shape);
                      break;
                  }
                  case 'stringType': {
                      result = new coreml.TensorType('string');
                      break;
                  }
                  case 'doubleType': {
                      result = new coreml.TensorType('float64');
                      break;
                  }
                  case 'int64Type': {
                      result = new coreml.TensorType('int64');
                      break;
                  }
                  case 'dictionaryType': {
                      result = new coreml.MapType(type.dictionaryType.KeyType.replace('KeyType', ''), 'float64');
                      break;
                  }
                  case 'imageType': {
                      result = new coreml.ImageType(type.imageType.colorSpace, type.imageType.width, type.imageType.height);
                      break;
                  }
              }
              if (type.isOptional) {
                  result = new coreml.OptionalType(result);
              }
          }
          return result;
      }

      // Converts a MIL value type into a coreml.TensorType.
      // Throws coreml.Error for value kinds other than tensorType/listType and for data type
      // codes that are not listed in coreml.proto.MILSpec.DataType.
      static type(type) {
          let valueType;
          switch (type.type) {
              case 'tensorType':
                  valueType = type.tensorType;
                  break;
              case 'listType':
                  // NOTE(review): reads `dimensions`/`dataType` directly from the list type — confirm
                  // the list message actually carries these fields.
                  valueType = type.listType;
                  break;
              default:
                  throw new coreml.Error("Unsupported value type '" + type.type + "'.");
          }
          if (!coreml.Utility._dataTypes) {
              coreml.Utility._dataTypes = new Map();
              for (const [ key, code ] of Object.entries(coreml.proto.MILSpec.DataType)) {
                  if (key === 'UNUSED_TYPE') {
                      continue;
                  }
                  // Enum keys are upper-case (cf. 'UNUSED_TYPE' above), so compare after lower-casing
                  // to map BOOL onto 'boolean'.
                  const name = key.toLowerCase();
                  coreml.Utility._dataTypes.set(code, name === 'bool' ? 'boolean' : name);
              }
          }
          const shape = valueType.dimensions.map((dim) => dim.constant ? dim.constant.size : '?');
          const dataType = coreml.Utility._dataTypes.get(valueType.dataType);
          // Map#get yields undefined (never null) for an unknown code.
          if (dataType === undefined) {
              throw new coreml.Error("Unsupported data type '" + valueType.dataType + "'.");
          }
          return new coreml.TensorType(dataType, new coreml.TensorShape(shape));
      }
  };
  coreml.Metadata = class {

      // Requests 'coreml-metadata.json' through `context.request` and caches the resulting
      // instance on the class. A failed request or unparsable JSON yields empty metadata.
      static open(context) {
          const cached = coreml.Metadata._metadata;
          if (cached) {
              return Promise.resolve(cached);
          }
          const store = (data) => {
              coreml.Metadata._metadata = new coreml.Metadata(data);
              return coreml.Metadata._metadata;
          };
          return context.request('coreml-metadata.json', 'utf-8', null)
              .then((data) => store(data))
              .catch(() => store(null));
      }

      // `data` is a JSON string holding an array of schema entries keyed by their `name`.
      constructor(data) {
          const entries = data ? JSON.parse(data).map((item) => [ item.name, item ]) : [];
          this._map = new Map(entries);
          this._attributeCache = new Map();
          this._inputCache = new Map();
      }

      type(name) {
          return this._map.get(name);
      }

      // Attribute schema for `type`/`name`, or null. The first lookup for a type caches all of
      // that type's attribute schemas.
      attribute(type, name) {
          const key = type + ':' + name;
          if (!this._attributeCache.has(key)) {
              this._attributeCache.set(key, null);
              const schema = this.type(type);
              const attributes = schema && Array.isArray(schema.attributes) ? schema.attributes : [];
              for (const attribute of attributes) {
                  this._attributeCache.set(type + ':' + attribute.name, attribute);
              }
          }
          return this._attributeCache.get(key);
      }

      // false only when the input schema for `type`/`name` explicitly sets `visible: false`.
      visible(type, name) {
          const key = type + ':' + name;
          if (!this._inputCache.has(key)) {
              this._inputCache.set(key, null);
              const schema = this.type(type);
              const inputs = schema && Array.isArray(schema.inputs) ? schema.inputs : [];
              for (const input of inputs) {
                  this._inputCache.set(type + ':' + input.name, input);
              }
          }
          const input = this._inputCache.get(key);
          return input ? input.visible !== false : true;
      }

      // Groups input argument names under the schema's input names. A 'variadic' input absorbs
      // all remaining arguments; unnamed slots are labelled '(index)'. Without a schema input
      // list the first slot is called 'input'.
      getInputs(type, inputs) {
          const schema = this._map.get(type);
          const declared = schema && schema.inputs;
          const results = [];
          for (let index = 0; index < inputs.length;) {
              let name = null;
              let count = 1;
              if (declared) {
                  if (index < declared.length) {
                      name = declared[index].name;
                      if (declared[index].option == 'variadic') {
                          count = inputs.length - index;
                      }
                  }
              }
              else if (index === 0) {
                  name = 'input';
              }
              const names = inputs.slice(index, index + count);
              results.push({
                  arguments: Array.from(names, (argument) => ({ name: argument })),
                  name: name ? name : '(' + index.toString() + ')'
              });
              index += count;
          }
          return results;
      }

      // Schema output name at `index`, falling back to 'output' for index 0 and '(index)' otherwise.
      getOutputName(type, index) {
          const schema = this._map.get(type);
          const outputs = schema ? schema.outputs : null;
          const output = outputs && index < outputs.length ? outputs[index] : null;
          if (output && output.name) {
              return output.name;
          }
          return index === 0 ? 'output' : '(' + index.toString() + ')';
      }
  };
  coreml.Error = class extends Error {
      // Error raised by this module; `name` is set to a fixed descriptive title.
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading Core ML model.' });
      }
  };
  // When loaded as a CommonJS module, expose the model factory; otherwise do nothing.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: coreml.ModelFactory });
  }