// sklearn.js (20 KB)
  /* jshint esversion: 6 */
  // Experimental
  var sklearn = sklearn || {};

  // Factory that recognizes unpickled scikit-learn style estimators and wraps
  // them in a sklearn.Model.
  sklearn.ModelFactory = class {

      // Returns true when the object produced by context.open('pkl') is an
      // estimator, or a non-empty array of estimators, whose qualified class
      // name starts with 'sklearn.', 'xgboost.sklearn.' or 'lightgbm.sklearn.'.
      match(context) {
          const obj = context.open('pkl');
          const validate = (candidate) => {
              if (candidate && candidate.__class__ && candidate.__class__.__module__ && candidate.__class__.__name__) {
                  const key = candidate.__class__.__module__ + '.' + candidate.__class__.__name__;
                  return key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.');
              }
              return false;
          };
          if (validate(obj)) {
              return true;
          }
          // every() is vacuously true for an empty array, so an empty list must
          // be rejected explicitly; it contains no estimator to identify.
          return Array.isArray(obj) && obj.length > 0 && obj.every((item) => validate(item));
      }

      // Resolves the shared metadata first, then builds a sklearn.Model from
      // the object returned by context.open('pkl'). Returns a promise.
      open(context) {
          return sklearn.Metadata.open(context).then((metadata) => {
              const obj = context.open('pkl');
              return new sklearn.Model(metadata, obj);
          });
      }
  };
  // Top-level model: one graph per pickled estimator.
  sklearn.Model = class {

      // obj is either a single estimator or an array of estimators.
      // A single estimator contributes its _sklearn_version (when present) to
      // the format string; arrays produce one graph per element, named by index.
      constructor(metadata, obj) {
          const single = !Array.isArray(obj);
          const version = single && obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '';
          this._format = 'scikit-learn' + version;
          this._graphs = single ?
              [ new sklearn.Graph(metadata, '', obj) ] :
              Array.from(obj, (item, index) => new sklearn.Graph(metadata, index.toString(), item));
      }

      get format() {
          return this._format;
      }

      get graphs() {
          return this._graphs;
      }
  };
  // Flattens an estimator, including nested Pipeline / FeatureUnion /
  // ColumnTransformer containers, into a list of sklearn.Node instances that
  // are connected through named arguments.
  sklearn.Graph = class {

      constructor(metadata, name, obj) {
          this._name = name || '';
          this._metadata = metadata;
          this._nodes = [];
          this._groups = false;
          // The root estimator is fed by a single input argument named 'data'.
          this._process('', '', obj, ['data']);
      }

      // Adds the node(s) for obj under the given group and returns the names
      // of the arguments obj produces, so the caller can wire the next stage.
      _process(group, name, obj, inputs) {
          const type = obj.__class__.__module__ + '.' + obj.__class__.__name__;
          switch (type) {
              case 'sklearn.pipeline.Pipeline': {
                  this._groups = true;
                  name = name || 'pipeline';
                  const childGroup = this._concat(group, name);
                  for (const step of obj.steps) {
                      const estimator = step[1];
                      // scikit-learn allows a step to be None or 'passthrough';
                      // such a step leaves the data untouched, so it adds no
                      // node and the inputs flow on to the next step.
                      if (estimator == null || estimator === 'passthrough') {
                          continue;
                      }
                      inputs = this._process(childGroup, step[0], estimator, inputs);
                  }
                  return inputs;
              }
              case 'sklearn.pipeline.FeatureUnion': {
                  return this._fanOut(group, name || 'union', obj, obj.transformer_list, inputs);
              }
              case 'sklearn.compose._column_transformer.ColumnTransformer': {
                  return this._fanOut(group, name || 'transformer', obj, obj.transformers, inputs);
              }
              default: {
                  const output = this._concat(group, name);
                  // An unnamed estimator has no output argument to publish.
                  this._add(group, output, obj, inputs, output === '' ? [] : [ output ]);
                  return [ output ];
              }
          }
      }

      // Shared handling for containers that feed one input to several parallel
      // transformers: emits a node for the container itself, then processes
      // every entry of `transformers` (tuples whose first two items are the
      // name and the estimator) and returns the concatenated output names.
      _fanOut(group, name, obj, transformers, inputs) {
          this._groups = true;
          const output = this._concat(group, name);
          const subgroup = output;
          this._add(subgroup, output, obj, inputs, [ output ]);
          const outputs = [];
          for (const transformer of transformers) {
              const estimator = transformer[1];
              // 'drop' (or None) discards the branch entirely.
              if (estimator == null || estimator === 'drop') {
                  continue;
              }
              // 'passthrough' forwards the container output unchanged.
              if (estimator === 'passthrough') {
                  outputs.push(output);
                  continue;
              }
              outputs.push(...this._process(subgroup, transformer[0], estimator, [ output ]));
          }
          return outputs;
      }

      // Creates a sklearn.Node for obj. Every public (non '_' prefixed)
      // property holding a tensor is appended to the node inputs as an
      // initializer parameter.
      _add(group, name, obj, inputs, outputs) {
          const initializers = [];
          for (const key of Object.keys(obj)) {
              if (!key.startsWith('_')) {
                  const value = obj[key];
                  if (sklearn.Utility.isTensor(value)) {
                      initializers.push(new sklearn.Tensor(key, value));
                  }
              }
          }
          inputs = inputs.map((input) => {
              return new sklearn.Parameter(input, [ new sklearn.Argument(input, null, null) ]);
          });
          inputs.push(...initializers.map((initializer) => {
              return new sklearn.Parameter(initializer.name, [ new sklearn.Argument('', null, initializer) ]);
          }));
          outputs = outputs.map((output) => {
              return new sklearn.Parameter(output, [ new sklearn.Argument(output, null, null) ]);
          });
          this._nodes.push(new sklearn.Node(this._metadata, group, name, obj, inputs, outputs));
      }

      // Joins a group path and a child name with '/', omitting an empty parent.
      _concat(parent, name){
          return (parent === '' ? name : `${parent}/${name}`);
      }

      get name() {
          return this._name;
      }

      get groups() {
          return this._groups;
      }

      get inputs() {
          return [];
      }

      get outputs() {
          return [];
      }

      get nodes() {
          return this._nodes;
      }
  };
  // A named node input or output; groups the arguments bound to that name.
  sklearn.Parameter = class {

      // name: parameter label; args: array of sklearn.Argument instances.
      constructor(name, args) {
          this._arguments = args;
          this._name = name;
      }

      get name() {
          return this._name;
      }

      // Parameters are always reported as visible.
      get visible() {
          return true;
      }

      get arguments() {
          return this._arguments;
      }
  };
  // A single value flowing along an edge, optionally backed by a tensor
  // initializer.
  sklearn.Argument = class {

      // Throws sklearn.Error when name is not a string. A falsy type or
      // initializer is normalized to null.
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              throw new sklearn.Error(`Invalid argument identifier '${JSON.stringify(name)}'.`);
          }
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      get name() {
          return this._name;
      }

      // The initializer's type takes precedence over the explicitly given one.
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      get initializer() {
          return this._initializer;
      }
  };
  // One estimator / transformer in the graph. Public (non '_' prefixed)
  // properties of the unpickled object become either tensor initializers or
  // attributes.
  sklearn.Node = class {

      constructor(metadata, group, name, obj, inputs, outputs) {
          this._group = group || '';
          this._name = name || '';
          // Qualified Python class name; objects without __class__ are shown as 'Object'.
          const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Object';
          // Fall back to a minimal schema when the metadata does not know this type.
          this._type = metadata.type(type) || { name: type };
          this._inputs = inputs;
          this._outputs = outputs;
          this._attributes = [];
          this._initializers = [];
          for (const key of Object.keys(obj)) {
              if (key.startsWith('_')) {
                  continue;
              }
              const value = obj[key];
              if (value && !Array.isArray(value) && value === Object(value) && sklearn.Utility.isTensor(value)) {
                  this._initializers.push(new sklearn.Tensor(key, value));
              }
              else {
                  // The attribute schema is looked up by the qualified type
                  // NAME. Passing the resolved schema object instead would be
                  // stringified to '[object Object]' inside the lookup key and
                  // never match, so per-attribute defaults and visibility
                  // would be ignored.
                  const schema = metadata.attribute(type, key);
                  this._attributes.push(new sklearn.Attribute(schema, key, value));
              }
          }
      }

      // Type schema object (at minimum { name }).
      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      // Group path, or null when the node is not inside a group.
      get group() {
          return this._group ? this._group : null;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get attributes() {
          return this._attributes;
      }
  };
  // A node attribute together with the visibility derived from its schema.
  sklearn.Attribute = class {

      // schema may be null/undefined; in that case the attribute stays visible.
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
          const declares = (key) => Object.prototype.hasOwnProperty.call(schema, key);
          if (schema) {
              // Checked in order; the first rule that applies hides the attribute:
              //   1. optional attribute without a value,
              //   2. schema explicitly marks it as not visible,
              //   3. value equals the schema default.
              const hidden =
                  (declares('option') && schema.option == 'optional' && this._value == null) ||
                  (declares('visible') && !schema.visible) ||
                  (declares('default') && sklearn.Attribute._isEquivalent(schema.default, this._value));
              if (hidden) {
                  this._visible = false;
              }
          }
          // Nested Python objects expose their qualified class name as the type.
          if (value && value.__class__) {
              this._type = value.__class__.__module__ + '.' + value.__class__.__name__;
          }
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      get type() {
          return this._type;
      }

      // Visible unless the constructor explicitly hid the attribute.
      get visible() {
          return !(this._visible == false);
      }

      // Structural deep equality used to compare a value with its schema default.
      static _isEquivalent(a, b) {
          if (a === b) {
              // Distinguishes +0 from -0, which are === but not equivalent.
              return a !== 0 || 1 / a === 1 / b;
          }
          if (a == null || b == null) {
              return false;
          }
          // NaN is the only value that is not equal to itself.
          if (a !== a) {
              return b !== b;
          }
          const kind = typeof a;
          if (kind !== 'function' && kind !== 'object' && typeof b != 'object') {
              return false;
          }
          // Bare toString resolves through the global object, as in the original code.
          const tag = toString.call(a);
          if (tag !== toString.call(b)) {
              return false;
          }
          if (tag === '[object RegExp]' || tag === '[object String]') {
              return '' + a === '' + b;
          }
          if (tag === '[object Number]') {
              if (+a !== +a) {
                  return +b !== +b;
              }
              return +a === 0 ? 1 / +a === 1 / b : +a === +b;
          }
          if (tag === '[object Date]' || tag === '[object Boolean]') {
              return +a === +b;
          }
          if (tag === '[object Array]') {
              if (a.length !== b.length) {
                  return false;
              }
              for (let index = 0; index < a.length; index++) {
                  if (!sklearn.Attribute._isEquivalent(a[index], b[index])) {
                      return false;
                  }
              }
              return true;
          }
          // Generic objects: same number of own enumerable keys, each equivalent.
          const keys = Object.keys(a);
          if (Object.keys(b).length != keys.length) {
              return false;
          }
          for (const key of keys) {
              const same = Object.prototype.hasOwnProperty.call(b, key) && sklearn.Attribute._isEquivalent(a[key], b[key]);
              if (!same) {
                  return false;
              }
          }
          return true;
      }
  };
  // Wraps an unpickled numpy.ndarray and decodes its raw bytes on demand.
  sklearn.Tensor = class {

      // Throws sklearn.Error when value is not recognized by
      // sklearn.Utility.isTensor.
      constructor(name, value) {
          this._name = name;
          if (!sklearn.Utility.isTensor(value)) {
              const type = value.__class__.__module__ + '.' + value.__class__.__name__;
              throw new sklearn.Error("Unknown tensor type '" + type + "'.");
          }
          this._kind = 'NumPy Array';
          this._type = new sklearn.TensorType(value.dtype.name, new sklearn.TensorShape(value.shape));
          this._data = value.data;
          if (value.dtype.name === 'string') {
              // Fixed byte width of each string element.
              this._itemsize = value.dtype.itemsize;
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get kind() {
          return this._kind;
      }

      // Reason the tensor cannot be decoded, or null when it can.
      get state() {
          return this._context().state || null;
      }

      // Fully decoded (nested array) value, or null when the tensor cannot be decoded.
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      // Text rendering; decoding stops with a '...' marker once the element
      // limit of 10000 is exceeded.
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          switch (this._type.dataType) {
              case 'int64':
              case 'uint64':
                  // 64-bit values are rendered through their own toString().
                  return sklearn.Tensor._stringify(value, '', ' ');
          }
          return JSON.stringify(value, null, 4);
      }

      // Builds the decoding state. context.state is set to a message when the
      // tensor cannot be decoded.
      _context() {
          const context = {};
          context.index = 0;
          context.count = 0;
          context.state = null;
          if (!this._type) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          context.dataType = this._type.dataType;
          context.dimensions = this._type.shape.dimensions;
          // Numeric types: DataView accessor name and element size in bytes.
          // NOTE(review): getInt64 / getUint64 are not standard DataView
          // methods; presumably the host application provides them - confirm.
          const numeric = new Map([
              [ 'float32', [ 'getFloat32', 4 ] ],
              [ 'float64', [ 'getFloat64', 8 ] ],
              [ 'int32', [ 'getInt32', 4 ] ],
              [ 'uint32', [ 'getUint32', 4 ] ],
              [ 'int64', [ 'getInt64', 8 ] ],
              [ 'uint64', [ 'getUint64', 8 ] ]
          ]);
          if (numeric.has(context.dataType)) {
              const reader = numeric.get(context.dataType);
              context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
              context.getter = reader[0];
              context.itemsize = reader[1];
          }
          else if (context.dataType === 'string') {
              context.data = this._data;
              context.itemsize = this._itemsize;
              context.decoder = new TextDecoder('utf-8');
          }
          else {
              context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
          }
          return context;
      }

      // Recursively decodes one dimension. Values are read little-endian in
      // the order they are stored.
      _decode(context, dimension) {
          // A zero-dimensional array has an empty shape but still stores one
          // element: decode it as shape [1] and unwrap the result below.
          const scalar = context.dimensions.length === 0;
          const shape = scalar ? [ 1 ] : context.dimensions;
          const results = [];
          const size = shape[dimension];
          const innermost = dimension == shape.length - 1;
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (!innermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              if (context.view) {
                  results.push(context.view[context.getter](context.index, true));
              }
              else {
                  // Fixed-width string element, cut at the first NUL byte.
                  const buffer = context.data.subarray(context.index, context.index + context.itemsize);
                  const end = buffer.indexOf(0);
                  results.push(context.decoder.decode(end >= 0 ? buffer.subarray(0, end) : buffer));
              }
              context.index += context.itemsize;
              context.count++;
          }
          return scalar ? results[0] : results;
      }

      // JSON-like rendering that calls toString() on leaf values.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push('[');
              const items = value.map((item) => sklearn.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(']');
              return result.join('\n');
          }
          return indentation + value.toString();
      }
  };
  // Element data type plus shape of a tensor.
  sklearn.TensorType = class {

      // dataType: dtype name; shape: object whose toString() renders the dimensions.
      constructor(dataType, shape) {
          this._shape = shape;
          this._dataType = dataType;
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      // Data type immediately followed by the rendered shape.
      toString() {
          const suffix = this._shape.toString();
          return this.dataType + suffix;
      }
  };
  // List of tensor dimensions.
  sklearn.TensorShape = class {

      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      // '[d0,d1,...]', or an empty string when no dimensions are set.
      toString() {
          if (!this._dimensions) {
              return '';
          }
          const items = this._dimensions.map((dimension) => dimension.toString());
          return '[' + items.join(',') + ']';
      }
  };
  // Type and attribute schemas loaded from 'sklearn-metadata.json'.
  sklearn.Metadata = class {

      // Returns a promise for the shared Metadata instance, loading it on
      // first use. If the request (or parsing) fails, an empty Metadata is
      // cached instead, so a model can still be opened without schemas.
      static open(context) {
          if (sklearn.Metadata._metadata) {
              return Promise.resolve(sklearn.Metadata._metadata);
          }
          return context.request('sklearn-metadata.json', 'utf-8', null).then((data) => {
              sklearn.Metadata._metadata = new sklearn.Metadata(data);
              return sklearn.Metadata._metadata;
          }).catch(() => {
              sklearn.Metadata._metadata = new sklearn.Metadata(null);
              return sklearn.Metadata._metadata;
          });
      }

      // data: JSON text holding an array of type schemas, each with a `name`.
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (data) {
              const metadata = JSON.parse(data);
              this._map = new Map(metadata.map((item) => [ item.name, item ]));
          }
      }

      // Schema for a qualified type name, or undefined when unknown.
      type(name) {
          return this._map.get(name);
      }

      // Attribute schema for `name` on the given type, or null when unknown.
      // `type` may be the qualified type name or a type schema object with a
      // `name` property. Accepting the object matters because concatenating
      // an object into the cache key would yield '[object Object]:...' and
      // the lookup would never succeed.
      attribute(type, name) {
          const typeName = (type !== null && typeof type === 'object') ? type.name : type;
          const key = typeName + ':' + name;
          if (!this._attributeCache.has(key)) {
              // First request for this type: cache all of its attributes at once.
              const schema = this.type(typeName);
              if (schema && schema.attributes && schema.attributes.length > 0) {
                  for (const attribute of schema.attributes) {
                      this._attributeCache.set(typeName + ':' + attribute.name, attribute);
                  }
              }
              // Remember misses too, so unknown attributes are not rescanned.
              if (!this._attributeCache.has(key)) {
                  this._attributeCache.set(key, null);
              }
          }
          return this._attributeCache.get(key);
      }
  };
  // Helpers for recognizing tensors and tensor collections.
  sklearn.Utility = class {

      // Truthy when obj is an unpickled numpy.ndarray (class module 'numpy',
      // class name 'ndarray'); falsy otherwise.
      static isTensor(obj) {
          return obj && obj.__class__ && obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray';
      }

      // Looks for a collection made up solely of tensors, either obj itself
      // or obj.blobs, and returns it as a Map of name -> tensor (array
      // elements are keyed by their index). Returns null when none is found.
      //
      // Candidates are tried in order: obj as dictionary, obj.blobs as
      // dictionary, obj as array, obj.blobs as array. A candidate that
      // contains a non-tensor is rejected on its own, so the remaining
      // candidates (notably obj.blobs) are still considered.
      static findWeights(obj) {
          if (obj == null) {
              return null;
          }
          const candidates = [ '', 'blobs' ].map((key) => key === '' ? obj : obj[key]);

          // Map or plain-object candidate -> Map of tensors, or null.
          const fromDictionary = (dict) => {
              if (!dict || Array.isArray(dict)) {
                  return null;
              }
              const weights = new Map();
              if (dict instanceof Map) {
                  for (const pair of dict) {
                      if (!sklearn.Utility.isTensor(pair[1])) {
                          return null;
                      }
                      weights.set(pair[0], pair[1]);
                  }
                  return weights;
              }
              for (const name in dict) {
                  // 'weight_order' and 'lr' entries are skipped rather than
                  // treated as non-tensor failures.
                  if (name != 'weight_order' && name != 'lr') {
                      const value = dict[name];
                      if (!name || !sklearn.Utility.isTensor(value)) {
                          return null;
                      }
                      weights.set(name, value);
                  }
              }
              return weights;
          };

          // Array candidate -> Map keyed by element index, or null.
          const fromList = (list) => {
              if (!list || !Array.isArray(list)) {
                  return null;
              }
              const weights = new Map();
              for (let i = 0; i < list.length; i++) {
                  const value = list[i];
                  if (!sklearn.Utility.isTensor(value)) {
                      return null;
                  }
                  weights.set(i.toString(), value);
              }
              return weights;
          };

          for (const convert of [ fromDictionary, fromList ]) {
              for (const candidate of candidates) {
                  const weights = convert(candidate);
                  if (weights) {
                      return weights;
                  }
              }
          }
          return null;
      }
  };
  // Error type thrown by this module; `name` carries a fixed descriptive
  // label in place of the class name.
  sklearn.Error = class extends Error {

      constructor(message) {
          super(message);
          const label = 'Error loading scikit-learn model.';
          this.name = label;
      }
  };
  // Publish the factory on module.exports when a `module` object with an
  // object-valued `exports` is in scope; otherwise this is a no-op.
  if (typeof module !== 'undefined') {
      if (typeof module.exports === 'object') {
          module.exports.ModelFactory = sklearn.ModelFactory;
      }
  }