coreml.js 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710
  1. var coreml = coreml || {};
  2. var json = json || require('./json');
  3. var protobuf = protobuf || require('./protobuf');
  4. var base = base || require('./base');
  /**
   * Detects and opens Core ML artifacts: raw protobuf model files, `.mlpackage`
   * manifests, package metadata / feature-description JSON files and weight blobs.
   */
  coreml.ModelFactory = class {

      /**
       * Identifies which kind of Core ML artifact `context` refers to.
       * @param {object} context - host context exposing `stream`, `identifier`, `tags()` and `open()`.
       * @returns {string|null|undefined} 'coreml.pb', 'coreml.manifest', 'coreml.metadata',
       *     'coreml.featuredescriptions' or 'coreml.weights'; `null` for a `.pb` file whose
       *     top-level field numbers are not recognised; `undefined` when nothing matches.
       */
      match(context) {
          const stream = context.stream;
          const identifier = context.identifier.toLowerCase();
          const extension = identifier.split('.').pop();
          const tags = context.tags('pb');
          // NOTE(review): presumably field 1 = specificationVersion (varint) and
          // field 2 = description (length-delimited) of the Model message — confirm.
          if (tags.get(1) === 0 && tags.get(2) === 2) {
              if (extension === 'pb') {
                  // Generic `.pb` files must also contain at least one of these field numbers
                  // (presumably the Model `Type` oneof members — TODO confirm against Model.proto).
                  const fields = context.tags('pb+');
                  const keys = Object.keys(fields).map((key) => parseInt(key, 10));
                  const match = (key) =>
                      (key >= 200 && key < 220) ||
                      (key >= 300 && key < 320) ||
                      (key >= 400 && key < 420) ||
                      (key >= 500 && key < 520) ||
                      (key >= 550 && key < 560) ||
                      (key >= 600 && key < 620) ||
                      (key === 900) ||
                      (key >= 2000 && key < 2010) ||
                      (key === 3000);
                  if (!keys.some((key) => match(key))) {
                      return null;
                  }
              }
              return 'coreml.pb';
          }
          if (identifier === 'manifest.json') {
              const obj = context.open('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  const entries = Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]);
                  // A package manifest must list exactly one `.mlmodel` file.
                  if (entries.filter((entry) => coreml.ModelFactory._isModelEntry(entry)).length === 1) {
                      return 'coreml.manifest';
                  }
              }
          }
          if (identifier === 'metadata.json') {
              const obj = context.open('json');
              if (obj && obj.rootModelIdentifier && obj.itemInfoEntries) {
                  return 'coreml.metadata';
              }
          }
          if (identifier === 'featuredescriptions.json') {
              const obj = context.open('json');
              if (obj && (obj.Inputs || obj.Outputs)) {
                  return 'coreml.featuredescriptions';
              }
          }
          if (extension === 'bin' && stream.length > 16) {
              const buffer = stream.peek(Math.min(256, stream.length));
              // Scan every 4-byte window, including the final one, for the little-endian
              // 0xdeadbeef signature that also opens each weight blob header.
              for (let i = 0; i + 4 <= buffer.length; i++) {
                  const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer[i + 3] << 24) >>> 0;
                  if (signature === 0xdeadbeef) {
                      return 'coreml.weights';
                  }
              }
          }
          return undefined;
      }

      /**
       * Opens the artifact identified by `match` and builds a `coreml.Model`.
       * Package side files (metadata, feature descriptions, weights) locate the package
       * `Manifest.json` relative to themselves and open the `.mlmodel` it lists.
       * @param {object} context - host context (see `match`).
       * @param {string} match - value previously returned by `match`.
       * @returns {Promise<coreml.Model>}
       * @throws {coreml.Error} (as a rejection) for undecodable models, manifests without a
       *     `.mlmodel` entry, or unknown `match` values.
       */
      open(context, match) {
          return context.require('./coreml-proto').then(() => {
              return coreml.Metadata.open(context).then((metadata) => {
                  const openModel = (stream, context, path, format) => {
                      let model = null;
                      try {
                          coreml.proto = protobuf.get('coreml').CoreML.Specification;
                          const reader = protobuf.BinaryReader.open(stream);
                          model = coreml.proto.Model.decode(reader);
                      }
                      catch (error) {
                          const message = error && error.message ? error.message : String(error);
                          throw new coreml.Error('File format is not coreml.Model (' + message.replace(/\.$/, '') + ').');
                      }
                      // Collect every blob file referenced by ML Program operation attributes,
                      // including programs nested inside any kind of pipeline.
                      const weightPaths = new Set();
                      const walkProgram = (program) => {
                          for (const func of Object.values(program.functions)) {
                              for (const block of Object.values(func.block_specializations)) {
                                  for (const operation of block.operations) {
                                      for (const value of Object.values(operation.attributes)) {
                                          if (value.blobFileValue && value.blobFileValue.fileName) {
                                              weightPaths.add(value.blobFileValue.fileName);
                                          }
                                      }
                                  }
                              }
                          }
                      };
                      const walkModel = (model) => {
                          if (model.mlProgram) {
                              walkProgram(model.mlProgram);
                          }
                          const pipelines = [
                              model.pipeline,
                              model.pipelineClassifier && model.pipelineClassifier.pipeline,
                              model.pipelineRegressor && model.pipelineRegressor.pipeline
                          ];
                          for (const pipeline of pipelines) {
                              if (pipeline && pipeline.models) {
                                  for (const node of pipeline.models) {
                                      walkModel(node);
                                  }
                              }
                          }
                      };
                      walkModel(model);
                      if (weightPaths.size === 0) {
                          return new coreml.Model(metadata, format, model, new Map());
                      }
                      const items = path.split('/');
                      items.pop();
                      const folder = items.join('/');
                      const keys = Array.from(weightPaths);
                      // '@model_path' is replaced by the directory containing the model file. For a
                      // model at the root the placeholder is dropped instead of leaving an empty
                      // leading segment (which would produce a '/'-prefixed path).
                      const paths = keys.map((key) => {
                          const segments = key.split('/');
                          if (segments[0] === '@model_path') {
                              if (folder.length > 0) {
                                  segments[0] = folder;
                              }
                              else {
                                  segments.shift();
                              }
                          }
                          return segments.join('/');
                      });
                      const promises = paths.map((file) => context.request(file, null));
                      return Promise.all(promises).then((streams) => {
                          const weights = new Map();
                          keys.forEach((key, index) => weights.set(key, streams[index]));
                          return new coreml.Model(metadata, format, model, weights);
                      }).catch((/* err */) => {
                          // Best effort: if the weights cannot be fetched (or the model fails to
                          // build with them) fall back to an empty weight map.
                          return new coreml.Model(metadata, format, model, new Map());
                      });
                  };
                  const openManifest = (obj, context, path) => {
                      const entries = obj && obj.itemInfoEntries ?
                          Object.keys(obj.itemInfoEntries).map((key) => obj.itemInfoEntries[key]) : [];
                      const entry = entries.find((entry) => coreml.ModelFactory._isModelEntry(entry));
                      if (!entry) {
                          throw new coreml.Error('Manifest does not contain a Core ML model.');
                      }
                      const file = path + 'Data/' + entry.path;
                      return context.request(file, null).then((stream) => {
                          return openModel(stream, context, file, 'Core ML Package');
                      });
                  };
                  const openManifestStream = (context, path) => {
                      return context.request(path + 'Manifest.json', null).then((stream) => {
                          const reader = json.TextReader.open(stream);
                          const obj = reader.read();
                          return openManifest(obj, context, path);
                      });
                  };
                  switch (match) {
                      case 'coreml.pb': {
                          return openModel(context.stream, context, context.identifier);
                      }
                      case 'coreml.manifest': {
                          const obj = context.open('json');
                          return openManifest(obj, context, '');
                      }
                      case 'coreml.featuredescriptions':
                      case 'coreml.metadata': {
                          return openManifestStream(context, '../../');
                      }
                      case 'coreml.weights': {
                          return openManifestStream(context, '../../../');
                      }
                      default: {
                          throw new coreml.Error("Unsupported Core ML format '" + match + "'.");
                      }
                  }
              });
          });
      }

      /**
       * Tests whether a manifest `itemInfoEntries` value points at the `.mlmodel` file.
       * @param {object} entry - manifest entry; entries without a string `path` never match.
       * @returns {boolean}
       */
      static _isModelEntry(entry) {
          return Boolean(entry) && typeof entry.path === 'string' && entry.path.toLowerCase().endsWith('.mlmodel');
      }
  };
  /**
   * Top-level view of a decoded Core ML model: a format label, optional version and
   * description, author/license/user-defined metadata entries, and a single graph.
   */
  coreml.Model = class {

      /**
       * @param {object} metadata - operator metadata forwarded to `coreml.Graph`.
       * @param {string} [format] - container label; 'Core ML' when omitted.
       * @param {object} model - decoded Core ML `Model` message.
       * @param {Map} weights - blob file name -> stream, forwarded to `coreml.Graph`.
       */
      constructor(metadata, format, model, weights) {
          this._format = (format || 'Core ML') + ' v' + model.specificationVersion.toString();
          this._metadata = [];
          this._graphs = [ new coreml.Graph(metadata, model, weights) ];
          const properties = model.description && model.description.metadata;
          if (properties) {
              if (properties.versionString) {
                  this._version = properties.versionString;
              }
              if (properties.shortDescription) {
                  this._description = properties.shortDescription;
              }
              if (properties.author) {
                  this._metadata.push({ name: 'author', value: properties.author });
              }
              if (properties.license) {
                  this._metadata.push({ name: 'license', value: properties.license });
              }
              // Free-form key/value pairs stored by the model author (e.g. converter versions).
              if (properties.userDefined) {
                  for (const entry of Object.entries(properties.userDefined)) {
                      this._metadata.push({ name: entry[0], value: entry[1] });
                  }
              }
          }
      }

      /** @returns {string} container label followed by ' v' and the specification version. */
      get format() {
          return this._format;
      }

      /** @returns {string|null} the model's `versionString`, if any. */
      get version() {
          return this._version || null;
      }

      /** @returns {string|null} the model's `shortDescription`, if any. */
      get description() {
          return this._description || null;
      }

      /** @returns {Array<{name: string, value: string}>} metadata entries. */
      get metadata() {
          return this._metadata;
      }

      /** @returns {Array<coreml.Graph>} always a single graph. */
      get graphs() {
          return this._graphs;
      }
  };
  218. coreml.Graph = class {
  219. constructor(metadata, model, weights) {
  220. this._metadata = metadata;
  221. this._description = model.description;
  222. this._groups = false;
  223. this._inputs = [];
  224. this._outputs = [];
  225. this._nodes = [];
  226. if (this._description) {
  227. this._inputs = this._description.input.map((input) => {
  228. const argument = new coreml.Argument(input.name, coreml.Utility.featureType(input.type), input.shortDescription, null);
  229. return new coreml.Parameter(input.name, true, [ argument ]);
  230. });
  231. this._outputs = this._description.output.map((output) => {
  232. const argument = new coreml.Argument(output.name, coreml.Utility.featureType(output.type), output.shortDescription, null);
  233. return new coreml.Parameter(output.name, true, [ argument ]);
  234. });
  235. }
  236. this._type = this._loadModel(model, {}, '', weights);
  237. }
  238. get name() {
  239. return '';
  240. }
  241. get type() {
  242. return this._type;
  243. }
  244. get inputs() {
  245. return this._inputs;
  246. }
  247. get outputs() {
  248. return this._outputs;
  249. }
  250. get nodes() {
  251. return this._nodes;
  252. }
  253. get groups() {
  254. return this._groups;
  255. }
  256. _updateOutput(name, newName) {
  257. for (const node of this._nodes) {
  258. for (const output of node.outputs) {
  259. for (const argument of output.arguments) {
  260. if (argument.name === name) {
  261. argument.name = newName;
  262. }
  263. }
  264. }
  265. }
  266. return newName;
  267. }
  268. _updateClassifierOutput(group, classifier) {
  269. let labelProbabilityLayerName = classifier.labelProbabilityLayerName;
  270. if (!labelProbabilityLayerName && this._nodes.length > 0) {
  271. const node = this._nodes.slice(-1).pop();
  272. if (node && node.outputs.length == 1 && node.outputs[0].arguments.length == 1) {
  273. labelProbabilityLayerName = node.outputs[0].arguments[0].name;
  274. }
  275. }
  276. let predictedFeatureName = this._description.predictedFeatureName;
  277. let predictedProbabilitiesName = this._description.predictedProbabilitiesName;
  278. if ((predictedFeatureName || predictedProbabilitiesName) && labelProbabilityLayerName && classifier.ClassLabels) {
  279. predictedFeatureName = predictedFeatureName ? predictedFeatureName : '?';
  280. predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
  281. const labelProbabilityInput = this._updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
  282. const type = classifier.ClassLabels;
  283. const inputs = [
  284. new coreml.Parameter('input', true, [ new coreml.Argument(labelProbabilityInput) ])
  285. ];
  286. const outputs = [
  287. new coreml.Parameter('probabilities', true, [ new coreml.Argument(predictedProbabilitiesName) ]),
  288. new coreml.Parameter('feature', true, [ new coreml.Argument(predictedFeatureName) ])
  289. ];
  290. const node = new coreml.Node(this._metadata, this._group, type, null, '', classifier[type], inputs, outputs);
  291. this._nodes.push(node);
  292. }
  293. }
  294. _updatePreprocessing(scope, group, preprocessing) {
  295. if (preprocessing && preprocessing.length > 0) {
  296. const preprocessingInput = this._description.input[0].name;
  297. const inputNodes = [];
  298. for (const node of this._nodes) {
  299. if (node.inputs.some((input) => input.arguments.some((arg) => arg.name == preprocessingInput))) {
  300. inputNodes.push(node);
  301. }
  302. }
  303. let preprocessorOutput = preprocessingInput;
  304. let preprocessorIndex = 0;
  305. for (const p of preprocessing) {
  306. const input = p.featureName ? p.featureName : preprocessorOutput;
  307. preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString();
  308. this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]);
  309. preprocessorIndex++;
  310. }
  311. for (const node of inputNodes) {
  312. for (const input of node.inputs) {
  313. for (const arg of input.arguments) {
  314. if (arg.name === preprocessingInput) {
  315. arg.name = preprocessorOutput;
  316. }
  317. }
  318. }
  319. }
  320. }
  321. }
  322. _loadModel(model, scope, group, weights) {
  323. this._groups = this._groups | (group.length > 0 ? true : false);
  324. const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
  325. switch (model.Type) {
  326. case 'neuralNetworkClassifier': {
  327. const neuralNetworkClassifier = model.neuralNetworkClassifier;
  328. for (const layer of neuralNetworkClassifier.layers) {
  329. this._createNode(scope, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output);
  330. }
  331. this._updateClassifierOutput(group, neuralNetworkClassifier);
  332. this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing);
  333. return 'Neural Network Classifier';
  334. }
  335. case 'neuralNetwork': {
  336. const neuralNetwork = model.neuralNetwork;
  337. for (const layer of neuralNetwork.layers) {
  338. this._createNode(scope, group, layer.layer, layer.name, group === '' ? '' : description, layer[layer.layer], layer.input, layer.output);
  339. }
  340. this._updatePreprocessing(scope, group, neuralNetwork.preprocessing);
  341. return 'Neural Network';
  342. }
  343. case 'neuralNetworkRegressor': {
  344. const neuralNetworkRegressor = model.neuralNetworkRegressor;
  345. for (const layer of neuralNetworkRegressor.layers) {
  346. this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
  347. }
  348. this._updatePreprocessing(scope, group, neuralNetworkRegressor);
  349. return 'Neural Network Regressor';
  350. }
  351. case 'pipeline': {
  352. for (let i = 0; i < model.pipeline.models.length; i++) {
  353. this._loadModel(model.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipeline[' + i.toString() + ']');
  354. }
  355. return 'Pipeline';
  356. }
  357. case 'pipelineClassifier': {
  358. for (let i = 0; i < model.pipelineClassifier.pipeline.models.length; i++) {
  359. this._loadModel(model.pipelineClassifier.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineClassifier[' + i.toString() + ']');
  360. }
  361. return 'Pipeline Classifier';
  362. }
  363. case 'pipelineRegressor': {
  364. for (let i = 0; i < model.pipelineRegressor.pipeline.models.length; i++) {
  365. this._loadModel(model.pipelineRegressor.pipeline.models[i], scope, (group ? (group + '/') : '') + 'pipelineRegressor[' + i.toString() + ']');
  366. }
  367. return 'Pipeline Regressor';
  368. }
  369. case 'glmClassifier': {
  370. this._createNode(scope, group, 'glmClassifier', null, description,
  371. {
  372. classEncoding: model.glmClassifier.classEncoding,
  373. offset: model.glmClassifier.offset,
  374. weights: model.glmClassifier.weights
  375. },
  376. [ model.description.input[0].name ],
  377. [ model.description.predictedProbabilitiesName ]);
  378. this._updateClassifierOutput(group, model.glmClassifier);
  379. return 'Generalized Linear Classifier';
  380. }
  381. case 'glmRegressor': {
  382. this._createNode(scope, group, 'glmRegressor', null, description,
  383. model.glmRegressor,
  384. [ model.description.input[0].name ],
  385. [ model.description.output[0].name ]);
  386. return 'Generalized Linear Regressor';
  387. }
  388. case 'dictVectorizer': {
  389. this._createNode(scope, group, 'dictVectorizer', null, description,
  390. model.dictVectorizer,
  391. [ model.description.input[0].name ],
  392. [ model.description.output[0].name ]);
  393. return 'Dictionary Vectorizer';
  394. }
  395. case 'featureVectorizer': {
  396. this._createNode(scope, group, 'featureVectorizer', null, description,
  397. model.featureVectorizer,
  398. coreml.Graph._formatFeatureDescriptionList(model.description.input),
  399. [ model.description.output[0].name ]);
  400. return 'Feature Vectorizer';
  401. }
  402. case 'treeEnsembleClassifier': {
  403. this._createNode(scope, group, 'treeEnsembleClassifier', null, description,
  404. model.treeEnsembleClassifier.treeEnsemble,
  405. [ model.description.input[0].name ],
  406. [ model.description.output[0].name ]);
  407. this._updateClassifierOutput(group, model.treeEnsembleClassifier);
  408. return 'Tree Ensemble Classifier';
  409. }
  410. case 'treeEnsembleRegressor': {
  411. this._createNode(scope, group, 'treeEnsembleRegressor', null, description,
  412. model.treeEnsembleRegressor.treeEnsemble,
  413. [ model.description.input[0].name ],
  414. [ model.description.output[0].name ]);
  415. return 'Tree Ensemble Regressor';
  416. }
  417. case 'supportVectorClassifier': {
  418. this._createNode(scope, group, 'supportVectorClassifier', null, description,
  419. {
  420. coefficients: model.supportVectorClassifier.coefficients,
  421. denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
  422. kernel: model.supportVectorClassifier.kernel,
  423. numberOfSupportVectorsPerClass: model.supportVectorClassifier.numberOfSupportVectorsPerClass,
  424. probA: model.supportVectorClassifier.probA,
  425. probB: model.supportVectorClassifier.probB,
  426. rho: model.supportVectorClassifier.rho,
  427. supportVectors: model.supportVectorClassifier.supportVectors
  428. },
  429. [ model.description.input[0].name ],
  430. [ model.description.output[0].name ]);
  431. this._updateClassifierOutput(group, model.supportVectorClassifier);
  432. return 'Support Vector Classifier';
  433. }
  434. case 'supportVectorRegressor': {
  435. this._createNode(scope, group, 'supportVectorRegressor', null, description,
  436. {
  437. coefficients: model.supportVectorRegressor.coefficients,
  438. kernel: model.supportVectorRegressor.kernel,
  439. rho: model.supportVectorRegressor.rho,
  440. supportVectors: model.supportVectorRegressor.supportVectors
  441. },
  442. [ model.description.input[0].name ],
  443. [ model.description.output[0].name ]);
  444. return 'Support Vector Regressor';
  445. }
  446. case 'arrayFeatureExtractor': {
  447. this._createNode(scope, group, 'arrayFeatureExtractor', null, description,
  448. { extractIndex: model.arrayFeatureExtractor.extractIndex },
  449. [ model.description.input[0].name ],
  450. [ model.description.output[0].name ]);
  451. return 'Array Feature Extractor';
  452. }
  453. case 'oneHotEncoder': {
  454. const categoryType = model.oneHotEncoder.CategoryType;
  455. const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
  456. oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
  457. this._createNode(scope, group, 'oneHotEncoder', null, description,
  458. oneHotEncoderParams,
  459. [ model.description.input[0].name ],
  460. [ model.description.output[0].name ]);
  461. return 'One Hot Encoder';
  462. }
  463. case 'imputer': {
  464. const imputedValue = model.imputer.ImputedValue;
  465. const replaceValue = model.imputer.ReplaceValue;
  466. const imputerParams = {};
  467. imputerParams[imputedValue] = model.imputer[imputedValue];
  468. imputerParams[replaceValue] = model.imputer[replaceValue];
  469. this._createNode(scope, group, 'oneHotEncoder', null, description,
  470. imputerParams,
  471. [ model.description.input[0].name ],
  472. [ model.description.output[0].name ]);
  473. return 'Imputer';
  474. }
  475. case 'normalizer': {
  476. this._createNode(scope, group, 'normalizer', null, description,
  477. model.normalizer,
  478. [ model.description.input[0].name ],
  479. [ model.description.output[0].name ]);
  480. return 'Normalizer';
  481. }
  482. case 'wordTagger': {
  483. this._createNode(scope, group, 'wordTagger', null, description,
  484. model.wordTagger,
  485. [ model.description.input[0].name ],
  486. [
  487. model.wordTagger.tokensOutputFeatureName,
  488. model.wordTagger.tokenTagsOutputFeatureName,
  489. model.wordTagger.tokenLocationsOutputFeatureName,
  490. model.wordTagger.tokenLengthsOutputFeatureName
  491. ]);
  492. return 'Word Tagger';
  493. }
  494. case 'textClassifier': {
  495. this._createNode(scope, group, 'textClassifier', null, description,
  496. model.textClassifier,
  497. [ model.description.input[0].name ],
  498. [ model.description.output[0].name ]);
  499. return 'Text Classifier';
  500. }
  501. case 'nonMaximumSuppression': {
  502. const nonMaximumSuppressionParams = {
  503. pickTop: model.nonMaximumSuppression.pickTop,
  504. stringClassLabels: model.nonMaximumSuppression.stringClassLabels,
  505. iouThreshold: model.nonMaximumSuppression.iouThreshold,
  506. confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
  507. };
  508. this._createNode(scope, group, 'nonMaximumSuppression', null, description,
  509. nonMaximumSuppressionParams,
  510. [
  511. model.nonMaximumSuppression.confidenceInputFeatureName,
  512. model.nonMaximumSuppression.coordinatesInputFeatureName,
  513. model.nonMaximumSuppression.iouThresholdInputFeatureName,
  514. model.nonMaximumSuppression.confidenceThresholdInputFeatureName,
  515. ],
  516. [
  517. model.nonMaximumSuppression.confidenceOutputFeatureName,
  518. model.nonMaximumSuppression.coordinatesOutputFeatureName
  519. ]);
  520. return 'Non Maximum Suppression';
  521. }
  522. case 'visionFeaturePrint': {
  523. const visionFeaturePrintParams = {
  524. scene: model.visionFeaturePrint.scene
  525. };
  526. this._createNode(scope, group, 'visionFeaturePrint', null, description,
  527. visionFeaturePrintParams,
  528. [ model.description.input[0].name ],
  529. [ model.description.output[0].name ]);
  530. return 'Vision Feature Print';
  531. }
  532. case 'soundAnalysisPreprocessing': {
  533. this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description,
  534. model.soundAnalysisPreprocessing,
  535. [ model.description.input[0].name ],
  536. [ model.description.output[0].name ]);
  537. return 'Sound Analysis Preprocessing';
  538. }
  539. case 'kNearestNeighborsClassifier': {
  540. this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description,
  541. model.kNearestNeighborsClassifier,
  542. [ model.description.input[0].name ],
  543. [ model.description.output[0].name ]);
  544. this._updateClassifierOutput(group, model.kNearestNeighborsClassifier);
  545. return 'Nearest Neighbors Classifier';
  546. }
  547. case 'itemSimilarityRecommender': {
  548. this._createNode(scope, group, 'itemSimilarityRecommender', null, description,
  549. {
  550. itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
  551. itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
  552. },
  553. model.description.input.map((feature) => feature.name),
  554. model.description.output.map((feature) => feature.name));
  555. return 'Item Similarity Recommender';
  556. }
  557. case 'audioFeaturePrint': {
  558. this._createNode(scope, group, 'audioFeaturePrint', null, description,
  559. model.audioFeaturePrint,
  560. [ model.description.input[0].name ],
  561. [ model.description.output[0].name ]);
  562. return 'Audio Feature Print';
  563. }
  564. case 'linkedModel': {
  565. this._createNode(scope, group, 'linkedModel', null, description,
  566. model.linkedModel.linkedModelFile,
  567. [ model.description.input[0].name ],
  568. [ model.description.output[0].name ]);
  569. return 'Linked Model';
  570. }
  571. case 'customModel': {
  572. this._createNode(scope, group, 'customModel', null, description,
  573. { className: model.customModel.className, parameters: model.customModel.parameters },
  574. [ model.description.input[0].name ],
  575. [ model.description.output[0].name ]);
  576. return 'customModel';
  577. }
  578. case 'mlProgram': {
  579. return this._loadProgram(model.mlProgram, scope, group, weights);
  580. }
  581. default: {
  582. throw new coreml.Error("Unsupported model type '" + JSON.stringify(Object.keys(model)) + "'.");
  583. }
  584. }
  585. }
  586. _loadProgram(program, scope, group, weights) {
  587. // TODO: need to handle functions other than main?
  588. const main = program.functions.main;
  589. // TODO: need to handle more than one block specialization?
  590. const block = main.block_specializations.CoreML5 || main.block_specializations.CoreML6;
  591. const convertValue = (value) => {
  592. switch (value.value) {
  593. case 'immediateValue': {
  594. const tensor = value.immediateValue.tensor;
  595. let values = null;
  596. switch (tensor.value) {
  597. case 'ints':
  598. values = tensor.ints.values;
  599. break;
  600. case 'strings':
  601. values = tensor.strings.values;
  602. break;
  603. case 'bools':
  604. values = tensor.bools.values;
  605. break;
  606. case 'floats':
  607. values = tensor.floats.values;
  608. break;
  609. case 'bytes':
  610. values = tensor.bytes.values;
  611. break;
  612. default:
  613. throw new coreml.Error("Unsupported tensor value '" + tensor.value + "'.");
  614. }
  615. return values;
  616. }
  617. case 'blobFileValue': {
  618. const type = coreml.Utility.valueType(value.type);
  619. const blob = value.blobFileValue;
  620. const offset = blob.offset.toNumber();
  621. const file = blob.fileName;
  622. let data = null;
  623. const stream = weights.get(file);
  624. if (stream) {
  625. stream.seek(offset);
  626. const buffer = stream.read(32);
  627. const reader = new base.BinaryReader(buffer);
  628. const signature = reader.uint32();
  629. if (signature == 0xdeadbeef) {
  630. reader.uint32(); // dataType
  631. const size = reader.uint64();
  632. stream.seek(reader.uint64());
  633. const length = (type.shape.dimensions || []).reduce((a, b) => a * b, 1);
  634. switch (type.dataType) {
  635. case 'float32': {
  636. const buffer = stream.read(size);
  637. data = new Float32Array(buffer.buffer, buffer.byteOffset, length).slice();
  638. break;
  639. }
  640. case 'float16': {
  641. data = stream.read(size);
  642. break;
  643. }
  644. case 'uint8': {
  645. data = stream.read(size);
  646. break;
  647. }
  648. default:
  649. throw new coreml.Error("Unsupported blob data type '" + type.dataType + "'.");
  650. }
  651. }
  652. }
  653. return new coreml.Tensor('Blob', type, data);
  654. }
  655. default: {
  656. throw new coreml.Error("Unsupported value '" + value.value + "'.");
  657. }
  658. }
  659. };
  660. const args = new Map();
  661. const arg = (name) => {
  662. if (!args.has(name)) {
  663. args.set(name, { name: name, to: [], from: [] });
  664. }
  665. return args.get(name);
  666. };
  667. const operations = block.operations.map((op) => {
  668. const operation = {
  669. type: op.type,
  670. attributes: {}
  671. };
  672. for (const entry of Object.entries(op.attributes)) {
  673. const key = entry[0];
  674. const value = entry[1];
  675. operation.attributes[key] = convertValue(value);
  676. }
  677. operation.inputs = Object.entries(op.inputs).map((entry) => {
  678. const key = entry[0];
  679. const input = entry[1];
  680. const args = input.arguments.map((argument) => {
  681. if (argument.name) {
  682. const value = arg(argument.name);
  683. value.to.push(operation);
  684. return value;
  685. }
  686. return { value: argument.value };
  687. });
  688. return {
  689. name: key,
  690. arguments: args
  691. };
  692. });
  693. operation.outputs = op.outputs.map((output) => {
  694. const value = arg(output.name);
  695. value.type = coreml.Utility.valueType(output.type);
  696. value.from.push(operation);
  697. return {
  698. name: 'output',
  699. arguments: [ value ]
  700. };
  701. });
  702. return operation;
  703. });
  704. for (const op of operations) {
  705. if (op.type === 'const' && op.inputs.length === 0 &&
  706. op.outputs.length === 1 && op.outputs[0].arguments.length === 1) {
  707. const argument = op.outputs[0].arguments[0];
  708. if (op.attributes && op.attributes.val) {
  709. const type = argument.type;
  710. const data = op.attributes.val;
  711. if (data instanceof Uint8Array && data.length === 2 &&
  712. type.dataType === 'float16' && type.shape.dimensions.length === 0) {
  713. const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  714. argument.value = view.getFloat16(0, true);
  715. }
  716. else {
  717. argument.value = data;
  718. }
  719. argument.const = true;
  720. op.delete = true;
  721. }
  722. }
  723. }
  724. for (const op of operations) {
  725. for (const input of op.inputs) {
  726. if (input.arguments.length > 1 && input.arguments.some((argument) => argument.const)) {
  727. if (input.arguments.every((argument) => argument.value instanceof coreml.Tensor)) {
  728. continue;
  729. }
  730. for (const argument of input.arguments) {
  731. for (const from of argument.from) {
  732. from.delete = false;
  733. }
  734. delete argument.value;
  735. }
  736. }
  737. }
  738. }
  739. for (const op of operations) {
  740. if (op.delete) {
  741. continue;
  742. }
  743. op.inputs = op.inputs.filter((input) => {
  744. if (input.arguments.every((argument) => argument.value === undefined || argument.value instanceof coreml.Tensor)) {
  745. return true;
  746. }
  747. if (input.arguments.length === 1) {
  748. const argument = input.arguments[0];
  749. op.attributes[input.name] = argument.value;
  750. return false;
  751. }
  752. op.attributes[input.name] = input.arguments.map((argument) => argument.value[0]);
  753. return false;
  754. });
  755. }
  756. const tensors = new Map();
  757. const tensor = (arg) => {
  758. if (!tensors.has(arg.name)) {
  759. tensors.set(arg.name, new coreml.Argument(arg.name, arg.type, null, arg.value));
  760. }
  761. return tensors.get(arg.name);
  762. };
  763. for (const op of operations) {
  764. if (op.delete) {
  765. continue;
  766. }
  767. op.inputs = op.inputs.map((input) => new coreml.Parameter(input.name, true, input.arguments.map((argument) => tensor(argument))));
  768. op.outputs = op.outputs.map((output) => new coreml.Parameter(output.name, true, output.arguments.map((argument) => tensor(argument))));
  769. }
  770. for (const op of operations.filter((op) => !op.delete)) {
  771. const type = 'program:' + op.type;
  772. const metadata = this._metadata.type(type);
  773. if (metadata && Array.isArray(metadata.inputs)) {
  774. let index = 1;
  775. const map = new Map(metadata.inputs.map((input) => [ input.name, index++ ]));
  776. op.inputs.sort((a, b) => (map.get(a.name) || map.size) - (map.get(b.name) || map.size));
  777. }
  778. const node = new coreml.Node(this._metadata, group, type, null, null, op.attributes, op.inputs, op.outputs);
  779. this._nodes.push(node);
  780. }
  781. return 'ML Program';
  782. }
  783. _createNode(scope, group, type, name, description, data, inputs, outputs, outputTypes) {
  784. inputs = inputs.map((input) => scope[input] ? scope[input].argument : input);
  785. outputs = outputs.map((output) => {
  786. if (scope[output]) {
  787. scope[output].counter++;
  788. const next = output + '\n' + scope[output].counter.toString(); // custom argument id
  789. scope[output].argument = next;
  790. return next;
  791. }
  792. scope[output] = {
  793. argument: output,
  794. counter: 0
  795. };
  796. return output;
  797. });
  798. const initializers = [];
  799. const attributes = {};
  800. if (data) {
  801. const map = this._initialize(type, data, initializers);
  802. for (const key of Object.keys(data)) {
  803. if (map[key]) {
  804. continue;
  805. }
  806. attributes[key] = data[key];
  807. }
  808. }
  809. const inputParameters = this._metadata.getInputs(type, inputs).map((input) => {
  810. return new coreml.Parameter(input.name, true, input.arguments.map((argument) => {
  811. return new coreml.Argument(argument.name, argument.type, null, null);
  812. }));
  813. });
  814. inputParameters.push(...initializers);
  815. const outputParameters = outputs.map((output, index) => {
  816. const name = this._metadata.getOutputName(type, index);
  817. const outputType = outputTypes ? outputTypes[index] : null;
  818. return new coreml.Parameter(name, true, [ new coreml.Argument(output, outputType, null, null) ]);
  819. });
  820. const node = new coreml.Node(this._metadata, group, type, name, description, attributes, inputParameters, outputParameters);
  821. this._nodes.push(node);
  822. return node;
  823. }
  824. _initializer(type, initializers, kind, name, shape, data) {
  825. let dataType = '?';
  826. let quantization = null;
  827. let values = null;
  828. if (data) {
  829. if (data.floatValue && data.floatValue.length > 0) {
  830. values = data.floatValue;
  831. dataType = 'float32';
  832. }
  833. else if (data.float16Value && data.float16Value.length > 0) {
  834. values = data.float16Value; // byte[]
  835. dataType = 'float16';
  836. }
  837. else if (data.rawValue && data.rawValue.length > 0) {
  838. if (data.quantization) {
  839. values = data.rawValue;
  840. dataType = 'uint' + data.quantization.numberOfBits.toString();
  841. }
  842. else {
  843. shape = [];
  844. }
  845. }
  846. quantization = data.quantization || null;
  847. }
  848. const tensorType = new coreml.TensorType(dataType, new coreml.TensorShape(shape));
  849. const tensor = new coreml.Tensor(kind, tensorType, values, quantization);
  850. const argument = new coreml.Argument('', null, null, tensor);
  851. const visible = this._metadata.visible(type, name);
  852. initializers.push(new coreml.Parameter(name, visible, [ argument ]));
  853. }
  854. _initialize(type, data, initializers) {
  855. switch (type) {
  856. case 'convolution': {
  857. const weightsShape = [ data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1] ];
  858. if (data.isDeconvolution) {
  859. weightsShape[0] = data.kernelChannels;
  860. weightsShape[1] = Math.floor(data.outputChannels / (data.nGroups != 0 ? data.nGroups : 1));
  861. }
  862. this._initializer(type, initializers, 'Weights', 'weights', weightsShape, data.weights);
  863. if (data.hasBias) {
  864. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  865. }
  866. return { 'weights': true, 'bias': data.hasBias };
  867. }
  868. case 'innerProduct':
  869. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputChannels, data.inputChannels ], data.weights);
  870. if (data.hasBias) {
  871. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputChannels ], data.bias);
  872. }
  873. return { 'weights': true, 'bias': data.hasBias };
  874. case 'batchnorm':
  875. this._initializer(type, initializers, 'Weights', 'gamma', [ data.channels ], data.gamma);
  876. this._initializer(type, initializers, 'Weights', 'beta', [ data.channels ], data.beta);
  877. if (data.mean) {
  878. this._initializer(type, initializers, 'Weights', 'mean', [ data.channels ], data.mean);
  879. }
  880. if (data.variance) {
  881. this._initializer(type, initializers, 'Weights', 'variance', [ data.channels ], data.variance);
  882. }
  883. return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
  884. case 'embedding':
  885. this._initializer(type, initializers, 'Weights', 'weights', [ data.inputDim, data.outputChannels ], data.weights);
  886. return { 'weights': true };
  887. case 'loadConstant':
  888. case 'loadConstantND':
  889. this._initializer(type, initializers, 'Weights', 'data', data.shape, data.data);
  890. return { 'data': true };
  891. case 'scale':
  892. this._initializer(type, initializers, 'Weights', 'scale', data.shapeScale, data.scale);
  893. if (data.hasBias) {
  894. this._initializer(type, initializers, 'Weights', 'bias', data.shapeBias, data.bias);
  895. }
  896. return { 'scale': true, 'bias': data.hasBias };
  897. case 'bias':
  898. this._initializer(type, initializers, 'Weights', 'bias', data.shape, data.bias);
  899. return { 'bias': true };
  900. case 'simpleRecurrent':
  901. this._initializer(type, initializers, 'Weights', 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix);
  902. this._initializer(type, initializers, 'Weights', 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix);
  903. if (data.hasBiasVectors) {
  904. this._initializer(type, initializers, 'Weights', 'bias', [ data.outputVectorSize ], data.biasVector);
  905. }
  906. return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
  907. case 'gru': {
  908. const recursionMatrixShape = [ data.outputVectorSize, data.outputVectorSize ];
  909. const weightMatrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  910. const biasVectorShape = [ data.outputVectorSize ];
  911. this._initializer(type, initializers, 'Weights', 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix);
  912. this._initializer(type, initializers, 'Weights', 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix);
  913. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix);
  914. this._initializer(type, initializers, 'Weights', 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix);
  915. this._initializer(type, initializers, 'Weights', 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix);
  916. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix);
  917. if (data.hasBiasVectors) {
  918. this._initializer(type, initializers, 'Weights', 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector);
  919. this._initializer(type, initializers, 'Weights', 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector);
  920. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector);
  921. }
  922. return {
  923. 'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
  924. 'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
  925. 'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors
  926. };
  927. }
  928. case 'uniDirectionalLSTM':
  929. case 'biDirectionalLSTM': {
  930. const count = (type == 'uniDirectionalLSTM') ? 1 : 2;
  931. const matrixShape = [ data.outputVectorSize, data.inputVectorSize ];
  932. const vectorShape = [ data.outputVectorSize ];
  933. for (let i = 0; i < count; i++) {
  934. const weights = count == 1 ? data.weightParams : data.weightParams[i];
  935. const suffix = (i == 0) ? '' : '_rev';
  936. this._initializer(type, initializers, 'Weights', 'inputGateWeightMatrix' + suffix, matrixShape, weights.inputGateWeightMatrix);
  937. this._initializer(type, initializers, 'Weights', 'forgetGateWeightMatrix' + suffix, matrixShape, weights.forgetGateWeightMatrix);
  938. this._initializer(type, initializers, 'Weights', 'blockInputWeightMatrix' + suffix, matrixShape, weights.blockInputWeightMatrix);
  939. this._initializer(type, initializers, 'Weights', 'outputGateWeightMatrix' + suffix, matrixShape, weights.outputGateWeightMatrix);
  940. this._initializer(type, initializers, 'Weights', 'inputGateRecursionMatrix' + suffix, matrixShape, weights.inputGateRecursionMatrix);
  941. this._initializer(type, initializers, 'Weights', 'forgetGateRecursionMatrix' + suffix, matrixShape,weights.forgetGateRecursionMatrix);
  942. this._initializer(type, initializers, 'Weights', 'blockInputRecursionMatrix' + suffix, matrixShape, weights.blockInputRecursionMatrix);
  943. this._initializer(type, initializers, 'Weights', 'outputGateRecursionMatrix' + suffix, matrixShape, weights.outputGateRecursionMatrix);
  944. if (data.params.hasBiasVectors) {
  945. this._initializer(type, initializers, 'Weights', 'inputGateBiasVector' + suffix, vectorShape, weights.inputGateBiasVector);
  946. this._initializer(type, initializers, 'Weights', 'forgetGateBiasVector' + suffix, vectorShape, weights.forgetGateBiasVector);
  947. this._initializer(type, initializers, 'Weights', 'blockInputBiasVector' + suffix, vectorShape, weights.blockInputBiasVector);
  948. this._initializer(type, initializers, 'Weights', 'outputGateBiasVector' + suffix, vectorShape, weights.outputGateBiasVector);
  949. }
  950. if (data.params.hasPeepholeVectors) {
  951. this._initializer(type, initializers, 'Weights', 'inputGatePeepholeVector' + suffix, vectorShape, weights.inputGatePeepholeVector);
  952. this._initializer(type, initializers, 'Weights', 'forgetGatePeepholeVector' + suffix, vectorShape, weights.forgetGatePeepholeVector);
  953. this._initializer(type, initializers, 'Weights', 'outputGatePeepholeVector' + suffix, vectorShape, weights.outputGatePeepholeVector);
  954. }
  955. }
  956. return { 'weightParams': true };
  957. }
  958. case 'dictVectorizer':
  959. data.stringToIndex = this._convertVector(data.stringToIndex);
  960. return {};
  961. case 'wordTagger':
  962. data.modelParameterData = Array.from(data.modelParameterData);
  963. data.stringTags = this._convertVector(data.stringTags);
  964. return { tokensOutputFeatureName: true, tokenTagsOutputFeatureName: true, tokenLengthsOutputFeatureName: true, tokenLocationsOutputFeatureName: true };
  965. case 'textClassifier':
  966. data.modelParameterData = Array.from(data.modelParameterData);
  967. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  968. return {};
  969. case 'nonMaximumSuppression':
  970. data.stringClassLabels = this._convertVector(data.stringClassLabels);
  971. return {};
  972. default:
  973. return {};
  974. }
  975. }
  976. _convertVector(value) {
  977. if (value && Object.keys(value).length == 1 && value.vector) {
  978. return value.vector;
  979. }
  980. return value;
  981. }
  982. static _formatFeatureDescriptionList(list) {
  983. return list.map((item) => item.name);
  984. }
  985. };
  /**
   * A named input or output slot of a node, holding the arguments connected to it.
   */
  coreml.Parameter = class {
      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot should be displayed.
       * @param {Array} args - the coreml.Argument values bound to this slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, {
              _name: name,
              _visible: visible,
              _arguments: args
          });
      }
      get name() {
          return this._name;
      }
      get visible() {
          return this._visible;
      }
      get arguments() {
          return this._arguments;
      }
  };
  /**
   * A value flowing between nodes: a string identifier plus either a declared type or
   * a constant initializer, which then supplies the type and quantization instead.
   */
  coreml.Argument = class {
      /**
       * @param {string} name - value identifier; anything other than a string is rejected.
       * @param {*} type - declared type, reported only when there is no initializer.
       * @param {string} [description] - stored as null when falsy.
       * @param {coreml.Tensor} [initializer] - constant data; stored as null when falsy.
       * @throws {coreml.Error} when `name` is not a string.
       */
      constructor(name, type, description, initializer) {
          if (typeof name !== 'string') {
              throw new coreml.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type;
          this._description = description ? description : null;
          this._initializer = initializer ? initializer : null;
      }
      get name() {
          return this._name;
      }
      set name(value) {
          this._name = value;
      }
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }
      get description() {
          return this._description;
      }
      get quantization() {
          return this._initializer ? this._initializer.quantization : null;
      }
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A graph node: its type schema, name, description, input/output parameters and
   * coreml.Attribute wrappers for its attribute values.
   */
  coreml.Node = class {
      /**
       * @param {Object} metadata - registry providing type(name) and attribute(type, key).
       * @param {*} group - optional group; stored only when truthy.
       * @param {string} type - node type; the text after the last ':' becomes the type name.
       * @param {string} [name] - node name, defaults to ''.
       * @param {string} [description] - node description, defaults to ''.
       * @param {Object} [attributes] - attribute values keyed by attribute name.
       * @param {Array} inputs - input coreml.Parameter list.
       * @param {Array} outputs - output coreml.Parameter list.
       * @throws {coreml.Error} when `type` is empty.
       */
      constructor(metadata, group, type, name, description, attributes, inputs, outputs) {
          if (!type) {
              // coreml.Error, like every other failure raised in this file.
              throw new coreml.Error('Undefined node type.');
          }
          if (group) {
              this._group = group;
          }
          // Copy the schema so renaming below does not modify the shared metadata entry.
          this._type = Object.assign({}, metadata.type(type) || { name: type });
          this._type.name = type.split(':').pop();
          this._name = name || '';
          this._description = description || '';
          this._inputs = inputs;
          this._outputs = outputs;
          this._attributes = [];
          if (attributes) {
              for (const key of Object.keys(attributes)) {
                  const schema = metadata.attribute(type, key);
                  const value = attributes[key];
                  const attribute = new coreml.Attribute(schema, key, value);
                  this._attributes.push(attribute);
              }
          }
      }
      get type() {
          return this._type;
      }
      get name() {
          return this._name;
      }
      get description() {
          return this._description;
      }
      // NOTE(review): `_metadata` is never assigned in the constructor, so this getter
      // always returns undefined; the type schema is exposed through `type` instead.
      get metadata() {
          return this._metadata;
      }
      get group() {
          return this._group ? this._group : null;
      }
      get inputs() {
          return this._inputs;
      }
      get outputs() {
          return this._outputs;
      }
      get attributes() {
          return this._attributes;
      }
  };
  /**
   * A named node attribute. Its type and visibility are derived from the attribute's
   * metadata schema, when one is supplied.
   */
  coreml.Attribute = class {
      /**
       * @param {Object} [metadata] - attribute schema; reads `type`, `visible` and `default`.
       * @param {string} name - attribute name.
       * @param {*} value - raw value. When the resolved type names an enum under
       *     coreml.proto, the value is replaced by the enum key via coreml.Utility.enum.
       */
      constructor(metadata, name, value) {
          this._name = name;
          this._value = value;
          if (this._value instanceof coreml.Tensor) {
              this._type = 'tensor';
          }
          if (metadata) {
              if (metadata.type) {
                  this._type = metadata.type;
              }
              if (this._type && coreml.proto) {
                  this._value = coreml.Utility.enum(this._type, this._value);
              }
              if (Object.prototype.hasOwnProperty.call(metadata, 'visible') && !metadata.visible) {
                  this._visible = false;
              }
              else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
                  // Array items may be plain numbers (e.g. float lists) or objects exposing
                  // toNumber() (presumably 64-bit integer wrappers); only unwrap the latter,
                  // so plain numbers no longer raise a TypeError.
                  const comparable = Array.isArray(value) ?
                      value.map((item) => (item && typeof item.toNumber === 'function') ? item.toNumber() : item) :
                      value;
                  // Hide attributes whose value equals the schema default.
                  if (JSON.stringify(metadata.default) == JSON.stringify(comparable)) {
                      this._visible = false;
                  }
              }
          }
      }
      get name() {
          return this._name;
      }
      get type() {
          return this._type;
      }
      get value() {
          return this._value;
      }
      get visible() {
          return this._visible == false ? false : true;
      }
  };
  /**
   * Constant tensor data together with its tensor type and optional CoreML
   * quantization parameters.
   */
  coreml.Tensor = class {
      /**
       * @param {string} kind - label such as 'Weights' or 'Blob'.
       * @param {Object} type - tensor type providing `dataType` and `shape.dimensions`.
       * @param {*} data - float32: an indexable sequence of numbers; float16, uint8 and
       *     quantized data: a byte view exposing buffer/byteOffset/byteLength.
       * @param {Object} [quantization] - quantization parameters, if any.
       */
      constructor(kind, type, data, quantization) {
          this._kind = kind;
          this._type = type;
          this._data = data;
          this._quantization = quantization;
      }
      get kind() {
          return this._kind;
      }
      get type() {
          return this._type;
      }
      /**
       * Quantization summary: the lookup table rendered as 'index = value' pairs joined
       * by '; ', '?' for any other quantization, or null when there is none.
       */
      get quantization() {
          const quantization = this._quantization;
          if (!quantization) {
              return null;
          }
          const table = quantization.lookupTableQuantization;
          if (!table || !table.floatValue || !(table.floatValue.length > 0)) {
              return '?';
          }
          const values = table.floatValue;
          return Object.keys(values).map((key) => key.toString() + ' = ' + values[key].toString()).join('; ');
      }
      /** Reason the data cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }
      /** Every element as nested arrays following the shape, or null if undecodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }
      /**
       * Pretty-printed JSON of the elements ('' if undecodable); output is cut short with
       * a '...' marker once the element count exceeds 10000.
       */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          return JSON.stringify(this._decode(context, 0), null, 4);
      }
      /** Builds the decoding state: cursor, counters, element type and data accessor. */
      _context() {
          const context = {
              state: null,
              index: 0,
              count: 0,
              dataType: this._type.dataType,
              dimensions: this._type.shape.dimensions
          };
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const byteView = () => new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          if (context.dataType === 'float32') {
              context.data = this._data;
          }
          else if (context.dataType === 'float16' || context.dataType === 'uint8') {
              context.data = byteView();
          }
          else if (this._quantization) {
              context.dataType = 'quantization';
              context.bits = this._quantization.numberOfBits.toNumber();
              context.data = byteView();
          }
          else {
              context.state = 'Tensor data type is not implemented.';
          }
          return context;
      }
      /**
       * Reads dimension `dimension` recursively. At the innermost dimension each element
       * advances context.index by its stride (2 for float16, otherwise 1) and bumps
       * context.count; a '...' marker ends a row once count exceeds context.limit.
       */
      _decode(context, dimension) {
          const results = [];
          const size = context.dimensions[dimension];
          const innermost = dimension === context.dimensions.length - 1;
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (!innermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              switch (context.dataType) {
                  case 'float32':
                      results.push(this._data[context.index]);
                      context.index += 1;
                      break;
                  case 'float16':
                      results.push(context.data.getFloat16(context.index, true));
                      context.index += 2;
                      break;
                  case 'uint8':
                      results.push(context.data.getUint8(context.index));
                      context.index += 1;
                      break;
                  case 'quantization':
                      results.push(context.data.getBits(context.index, context.bits));
                      context.index += 1;
                      break;
                  default:
                      break;
              }
              context.count++;
          }
          return results;
      }
  };
  /**
   * A tensor type: an element data type name plus a coreml.TensorShape.
   */
  coreml.TensorType = class {
      /**
       * @param {string} dataType - element type name, e.g. 'float32'.
       * @param {coreml.TensorShape} [shape] - defaults to a shape with no dimensions.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape ? shape : new coreml.TensorShape([]);
      }
      get dataType() {
          return this._dataType;
      }
      get shape() {
          return this._shape;
      }
      /** The data type name immediately followed by the shape's text form. */
      toString() {
          const shapeText = this._shape.toString();
          return this.dataType + shapeText;
      }
  };
  /**
   * An ordered list of tensor dimensions (numbers or placeholders such as '?').
   */
  coreml.TensorShape = class {
      constructor(dimensions) {
          this._dimensions = dimensions;
      }
      get dimensions() {
          return this._dimensions;
      }
      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          if (dimensions && dimensions.length !== 0) {
              const parts = dimensions.map((dimension) => dimension.toString());
              return '[' + parts.join(',') + ']';
          }
          return '';
      }
  };
  /**
   * A list value type, displayed as 'list<element>'.
   */
  coreml.ListType = class {
      /**
       * @param {Object} elementType - type of the list elements.
       */
      constructor(elementType) {
          this._elementType = elementType;
      }
      // Exposes the component type, consistent with MapType, SequenceType and OptionalType.
      get elementType() {
          return this._elementType;
      }
      toString() {
          return 'list<' + this._elementType.toString() + '>';
      }
  };
  /**
   * A dictionary value type, displayed as 'map<key,value>'.
   */
  coreml.MapType = class {
      /**
       * @param {string} keyType - key type name.
       * @param {*} valueType - value type (a name or a type object with toString()).
       */
      constructor(keyType, valueType) {
          this._keyType = keyType;
          this._valueType = valueType;
      }
      get keyType() {
          return this._keyType;
      }
      get valueType() {
          return this._valueType;
      }
      toString() {
          const prefix = 'map<' + this._keyType + ',';
          const valueText = this._valueType.toString();
          return prefix + valueText + '>';
      }
  };
  /**
   * A sequence value type, displayed as 'sequence<element>'.
   */
  coreml.SequenceType = class {
      /**
       * @param {*} type - element type.
       */
      constructor(type) {
          this._type = type;
      }
      get type() {
          return this._type;
      }
      toString() {
          const open = 'sequence<';
          return open + this._type + '>';
      }
  };
  /**
   * An image feature type, displayed as 'image<colorSpace,WIDTHxHEIGHT>'.
   */
  coreml.ImageType = class {
      /**
       * @param {number} colorSpace - a coreml.proto.ImageFeatureType.ColorSpace value.
       * @param {*} width - image width (must support toString()).
       * @param {*} height - image height (must support toString()).
       * @throws {coreml.Error} for a color space not listed below.
       */
      constructor(colorSpace, width, height) {
          this._width = width;
          this._height = height;
          const ColorSpace = coreml.proto.ImageFeatureType.ColorSpace;
          const names = new Map([
              [ ColorSpace.GRAYSCALE, 'grayscale' ],
              [ ColorSpace.RGB, 'RGB' ],
              [ ColorSpace.BGR, 'BGR' ],
              [ ColorSpace.GRAYSCALE_FLOAT16, 'grayscale:float16' ]
          ]);
          if (!names.has(colorSpace)) {
              throw new coreml.Error("Unsupported image color space '" + colorSpace + "'.");
          }
          this._colorSpace = names.get(colorSpace);
      }
      toString() {
          const size = this._width.toString() + 'x' + this._height.toString();
          return 'image<' + this._colorSpace + ',' + size + '>';
      }
  };
  /**
   * Wraps a feature type that may be absent, displayed as 'optional<type>'.
   */
  coreml.OptionalType = class {
      /**
       * @param {Object} type - the wrapped type (must support toString()).
       */
      constructor(type) {
          this._type = type;
      }
      get type() {
          return this._type;
      }
      toString() {
          const inner = this._type.toString();
          return 'optional<' + inner + '>';
      }
  };
  /**
   * Static helpers that translate CoreML protobuf type messages into the display
   * types defined in this file.
   */
  coreml.Utility = class {
      /**
       * Maps an enum value to its key name. The enum is found by walking the dotted
       * `name` under coreml.proto (e.g. 'ArrayFeatureType.ArrayDataType'). `value` is
       * returned unchanged when the path does not resolve or the value is not a member.
       * Reverse lookup maps are cached per `name`.
       */
      static enum(name, value) {
          let type = coreml.proto;
          const parts = name.split('.');
          while (type && parts.length > 0) {
              type = type[parts.shift()];
          }
          if (type) {
              coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map();
              if (!coreml.Utility._enumKeyMap.has(name)) {
                  const map = new Map(Object.entries(type).map((pair) => [ pair[1], pair[0] ]));
                  coreml.Utility._enumKeyMap.set(name, map);
              }
              const map = coreml.Utility._enumKeyMap.get(name);
              if (map.has(value)) {
                  return map.get(value);
              }
          }
          return value;
      }
      /**
       * Converts a FeatureType message (discriminated by its `Type` field) into a display
       * type, wrapping it in coreml.OptionalType when `isOptional` is set.
       * @returns {Object|string} the type, or '?' when `type` is falsy.
       * @throws {coreml.Error} for unsupported feature or array data types.
       */
      static featureType(type) {
          let result = '?';
          if (type) {
              switch (type.Type) {
                  case 'multiArrayType': {
                      let shape = new coreml.TensorShape([]);
                      if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) {
                          shape = new coreml.TensorShape(type.multiArrayType.shape);
                      }
                      let dataType = '?';
                      const ArrayDataType = coreml.proto.ArrayFeatureType.ArrayDataType;
                      switch (type.multiArrayType.dataType) {
                          case ArrayDataType.INVALID_ARRAY_DATA_TYPE:
                              dataType = '?';
                              break;
                          case ArrayDataType.FLOAT16:
                              dataType = 'float16';
                              break;
                          case ArrayDataType.FLOAT32:
                              dataType = 'float32';
                              break;
                          case ArrayDataType.DOUBLE:
                              dataType = 'float64';
                              break;
                          case ArrayDataType.INT32:
                              dataType = 'int32';
                              break;
                          default:
                              throw new coreml.Error("Unsupported array data type '" + type.multiArrayType.dataType + "'.");
                      }
                      result = new coreml.TensorType(dataType, shape);
                      break;
                  }
                  case 'stringType': {
                      result = new coreml.TensorType('string');
                      break;
                  }
                  case 'doubleType': {
                      result = new coreml.TensorType('float64');
                      break;
                  }
                  case 'int64Type': {
                      result = new coreml.TensorType('int64');
                      break;
                  }
                  case 'dictionaryType': {
                      result = new coreml.MapType(type.dictionaryType.KeyType.replace('KeyType', ''), 'float64');
                      break;
                  }
                  case 'sequenceType': {
                      result = new coreml.SequenceType(coreml.Utility.featureType(type[type.Type]));
                      break;
                  }
                  case 'imageType': {
                      result = new coreml.ImageType(type.imageType.colorSpace, type.imageType.width, type.imageType.height);
                      break;
                  }
                  default: {
                      throw new coreml.Error("Unsupported feature type '" + type.Type + "'.");
                  }
              }
              if (type.isOptional) {
                  result = new coreml.OptionalType(result);
              }
          }
          return result;
      }
      /**
       * Converts a MIL TensorType message into a coreml.TensorType. Constant dimensions
       * use their size; any other dimension is shown as '?'.
       * @throws {coreml.Error} when the data type is not a known MIL DataType member.
       */
      static tensorType(type) {
          if (!coreml.Utility._dataTypes) {
              coreml.Utility._dataTypes = new Map();
              const DataType = coreml.proto.MILSpec.DataType;
              for (const [ key, value ] of Object.entries(DataType)) {
                  if (key === 'UNUSED_TYPE') {
                      continue;
                  }
                  // Enum keys are upper-case (e.g. 'BOOL'), so compare after lower-casing;
                  // the bool type is displayed as 'boolean'.
                  const name = key.toLowerCase();
                  coreml.Utility._dataTypes.set(value, name === 'bool' ? 'boolean' : name);
              }
          }
          const shape = type.dimensions.map((dim) => dim.constant ? dim.constant.size : '?');
          const dataType = coreml.Utility._dataTypes.get(type.dataType);
          // Map.get yields undefined (never null) for unknown keys.
          if (dataType === undefined) {
              throw new coreml.Error("Unsupported data type '" + type.dataType + "'.");
          }
          return new coreml.TensorType(dataType, new coreml.TensorShape(shape));
      }
      /**
       * Converts a MIL ValueType message (tensor or list) into a display type.
       * @throws {coreml.Error} for other value types.
       */
      static valueType(type) {
          switch (type.type) {
              case 'tensorType':
                  return coreml.Utility.tensorType(type.tensorType);
              case 'listType':
                  return new coreml.ListType(coreml.Utility.valueType(type.listType.type));
              default:
                  throw new coreml.Error("Unsupported value type '" + type.type + "'.");
          }
      }
  };
  1468. coreml.Metadata = class {
  1469. static open(context) {
  1470. if (coreml.Metadata._metadata) {
  1471. return Promise.resolve(coreml.Metadata._metadata);
  1472. }
  1473. return context.request('coreml-metadata.json', 'utf-8', null).then((data) => {
  1474. coreml.Metadata._metadata = new coreml.Metadata(data);
  1475. return coreml.Metadata._metadata;
  1476. }).catch(() => {
  1477. coreml.Metadata._metadata = new coreml.Metadata(null);
  1478. return coreml.Metadata._metadata;
  1479. });
  1480. }
  1481. constructor(data) {
  1482. this._map = new Map();
  1483. this._attributeCache = new Map();
  1484. this._inputCache = new Map();
  1485. if (data) {
  1486. const metadata = JSON.parse(data);
  1487. this._map = new Map(metadata.map((item) => [ item.name, item ]));
  1488. }
  1489. }
  1490. type(name) {
  1491. return this._map.get(name);
  1492. }
  1493. attribute(type, name) {
  1494. const key = type + ':' + name;
  1495. if (!this._attributeCache.has(key)) {
  1496. this._attributeCache.set(key, null);
  1497. const metadata = this.type(type);
  1498. if (metadata && Array.isArray(metadata.attributes) && metadata.attributes.length > 0) {
  1499. for (const attribute of metadata.attributes) {
  1500. this._attributeCache.set(type + ':' + attribute.name, attribute);
  1501. }
  1502. }
  1503. }
  1504. return this._attributeCache.get(key);
  1505. }
  1506. visible(type, name) {
  1507. const key = type + ':' + name;
  1508. if (!this._inputCache.has(key)) {
  1509. this._inputCache.set(key, null);
  1510. const metadata = this.type(type);
  1511. if (metadata && Array.isArray(metadata.inputs) && metadata.inputs.length > 0) {
  1512. for (const input of metadata.inputs) {
  1513. this._inputCache.set(type + ':' + input.name, input);
  1514. }
  1515. }
  1516. }
  1517. const input = this._inputCache.get(key);
  1518. if (input) {
  1519. return input.visible === false ? false : true;
  1520. }
  1521. return true;
  1522. }
  1523. getInputs(type, inputs) {
  1524. const results = [];
  1525. const schema = this._map.get(type);
  1526. let index = 0;
  1527. while (index < inputs.length) {
  1528. const result = { arguments: [] };
  1529. let count = 1;
  1530. let name = null;
  1531. if (schema && schema.inputs) {
  1532. if (index < schema.inputs.length) {
  1533. const input = schema.inputs[index];
  1534. name = input.name;
  1535. if (schema.inputs[index].option == 'variadic') {
  1536. count = inputs.length - index;
  1537. }
  1538. }
  1539. }
  1540. else if (index == 0) {
  1541. name = 'input';
  1542. }
  1543. result.name = name ? name : '(' + index.toString() + ')';
  1544. const array = inputs.slice(index, index + count);
  1545. for (let j = 0; j < array.length; j++) {
  1546. result.arguments.push({ name: array[j] });
  1547. }
  1548. index += count;
  1549. results.push(result);
  1550. }
  1551. return results;
  1552. }
  1553. getOutputName(type, index) {
  1554. const schema = this._map.get(type);
  1555. if (schema) {
  1556. const outputs = schema.outputs;
  1557. if (outputs && index < outputs.length) {
  1558. const output = outputs[index];
  1559. if (output) {
  1560. const name = output.name;
  1561. if (name) {
  1562. return name;
  1563. }
  1564. }
  1565. }
  1566. }
  1567. if (index == 0) {
  1568. return 'output';
  1569. }
  1570. return '(' + index.toString() + ')';
  1571. }
  1572. };
  1573. coreml.Error = class extends Error {
  1574. constructor(message) {
  1575. super(message);
  1576. this.name = 'Error loading Core ML model.';
  1577. }
  1578. };
  1579. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  1580. module.exports.ModelFactory = coreml.ModelFactory;
  1581. }