// sklearn.js (22 KB)
  1. // Experimental
  2. var sklearn = sklearn || {};
  3. sklearn.ModelFactory = class {
  4. match(context) {
  5. const obj = context.open('pkl');
  6. const validate = (obj, name) => {
  7. if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
  8. const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
  9. return key.startsWith(name);
  10. }
  11. return false;
  12. };
  13. const formats = [
  14. { name: 'sklearn.', format: 'sklearn' },
  15. { name: 'xgboost.sklearn.', format: 'sklearn' },
  16. { name: 'lightgbm.sklearn.', format: 'sklearn' },
  17. { name: 'scipy.', format: 'scipy' }
  18. ];
  19. for (const format of formats) {
  20. if (validate(obj, format.name)) {
  21. return format.format;
  22. }
  23. if (Array.isArray(obj) && obj.every((item) => validate(item, format.name))) {
  24. return format.format + '.list';
  25. }
  26. if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1], format.name))) {
  27. return format.format + '.map';
  28. }
  29. }
  30. return undefined;
  31. }
  32. open(context, match) {
  33. return sklearn.Metadata.open(context).then((metadata) => {
  34. const obj = context.open('pkl');
  35. return new sklearn.Model(metadata, match, obj);
  36. });
  37. }
  38. };
  39. sklearn.Model = class {
  40. constructor(metadata, match, obj) {
  41. const formats = new Map([ [ 'sklearn', 'scikit-learn' ], [ 'scipy', 'SciPy' ] ]);
  42. this._format = formats.get(match.split('.').shift());
  43. this._graphs = [];
  44. const version = [];
  45. switch (match) {
  46. case 'sklearn':
  47. case 'scipy': {
  48. version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
  49. this._graphs.push(new sklearn.Graph(metadata, '', obj));
  50. break;
  51. }
  52. case 'sklearn.list':
  53. case 'scipy.list': {
  54. const list = obj;
  55. for (let i = 0; i < list.length; i++) {
  56. const obj = list[i];
  57. this._graphs.push(new sklearn.Graph(metadata, i.toString(), obj));
  58. version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
  59. }
  60. break;
  61. }
  62. case 'sklearn.map':
  63. case 'scipy.map': {
  64. for (const entry of Object.entries(obj)) {
  65. const obj = entry[1];
  66. this._graphs.push(new sklearn.Graph(metadata, entry[0], obj));
  67. version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
  68. }
  69. break;
  70. }
  71. default: {
  72. throw new sklearn.Error("Unsupported scikit-learn format '" + match + "'.");
  73. }
  74. }
  75. if (version.every((value) => value === version[0])) {
  76. this._format += version[0];
  77. }
  78. }
  79. get format() {
  80. return this._format;
  81. }
  82. get graphs() {
  83. return this._graphs;
  84. }
  85. };
  86. sklearn.Graph = class {
  87. constructor(metadata, name, obj) {
  88. this._name = name || '';
  89. this._metadata = metadata;
  90. this._nodes = [];
  91. this._groups = false;
  92. this._process('', '', obj, ['data']);
  93. }
  94. _process(group, name, obj, inputs) {
  95. const type = obj.__class__.__module__ + '.' + obj.__class__.__name__;
  96. switch (type) {
  97. case 'sklearn.pipeline.Pipeline': {
  98. this._groups = true;
  99. name = name || 'pipeline';
  100. const childGroup = this._concat(group, name);
  101. for (const step of obj.steps) {
  102. inputs = this._process(childGroup, step[0], step[1], inputs);
  103. }
  104. return inputs;
  105. }
  106. case 'sklearn.pipeline.FeatureUnion': {
  107. this._groups = true;
  108. const outputs = [];
  109. name = name || 'union';
  110. const output = this._concat(group, name);
  111. const subgroup = this._concat(group, name);
  112. this._nodes.push(new sklearn.Node(this._metadata, subgroup, output, obj, inputs, [ output ]));
  113. for (const transformer of obj.transformer_list){
  114. outputs.push(...this._process(subgroup, transformer[0], transformer[1], [ output ]));
  115. }
  116. return outputs;
  117. }
  118. case 'sklearn.compose._column_transformer.ColumnTransformer': {
  119. this._groups = true;
  120. name = name || 'transformer';
  121. const output = this._concat(group, name);
  122. const subgroup = this._concat(group, name);
  123. const outputs = [];
  124. this._nodes.push(new sklearn.Node(this._metadata, subgroup, output, obj, inputs, [ output ]));
  125. for (const transformer of obj.transformers){
  126. if (transformer[1] !== 'passthrough') {
  127. outputs.push(...this._process(subgroup, transformer[0], transformer[1], [ output ]));
  128. }
  129. }
  130. return outputs;
  131. }
  132. default: {
  133. const output = this._concat(group, name);
  134. this._nodes.push(new sklearn.Node(this._metadata, group, output, obj, inputs, output === '' ? [] : [ output ]));
  135. return [ output ];
  136. }
  137. }
  138. }
  139. _concat(parent, name){
  140. return (parent === '' ? name : `${parent}/${name}`);
  141. }
  142. get name() {
  143. return this._name;
  144. }
  145. get groups() {
  146. return this._groups;
  147. }
  148. get inputs() {
  149. return [];
  150. }
  151. get outputs() {
  152. return [];
  153. }
  154. get nodes() {
  155. return this._nodes;
  156. }
  157. };
  158. sklearn.Parameter = class {
  159. constructor(name, args) {
  160. this._name = name;
  161. this._arguments = args;
  162. }
  163. get name() {
  164. return this._name;
  165. }
  166. get visible() {
  167. return true;
  168. }
  169. get arguments() {
  170. return this._arguments;
  171. }
  172. };
  173. sklearn.Argument = class {
  174. constructor(name, type, initializer) {
  175. if (typeof name !== 'string') {
  176. throw new sklearn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  177. }
  178. this._name = name;
  179. this._type = type || null;
  180. this._initializer = initializer || null;
  181. }
  182. get name() {
  183. return this._name;
  184. }
  185. get type() {
  186. if (this._initializer) {
  187. return this._initializer.type;
  188. }
  189. return this._type;
  190. }
  191. get initializer() {
  192. return this._initializer;
  193. }
  194. };
  195. sklearn.Node = class {
  196. constructor(metadata, group, name, obj, inputs, outputs) {
  197. this._group = group || '';
  198. this._name = name || '';
  199. const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : 'Object';
  200. this._type = metadata.type(type) || { name: type };
  201. this._inputs = inputs.map((input) => new sklearn.Parameter(input, [ new sklearn.Argument(input, null, null) ]));
  202. this._outputs = outputs.map((output) => new sklearn.Parameter(output, [ new sklearn.Argument(output, null, null) ]));
  203. this._attributes = [];
  204. for (const entry of Object.entries(obj)) {
  205. const name = entry[0];
  206. const value = entry[1];
  207. if (value && sklearn.Utility.isTensor(value)) {
  208. const argument = new sklearn.Argument('', null, new sklearn.Tensor(value));
  209. const paramter = new sklearn.Parameter(name, [ argument ]);
  210. this._inputs.push(paramter);
  211. }
  212. else if (Array.isArray(value) && value.every((obj) => sklearn.Utility.isTensor(obj))) {
  213. const args = value.map((obj) => new sklearn.Argument('', null, new sklearn.Tensor(obj)));
  214. const paramter = new sklearn.Parameter(name, args);
  215. this._inputs.push(paramter);
  216. }
  217. else if (!name.startsWith('_')) {
  218. this._attributes.push(new sklearn.Attribute(metadata.attribute(this._type, name), name, value));
  219. }
  220. }
  221. }
  222. get type() {
  223. return this._type; // .split('.').pop();
  224. }
  225. get name() {
  226. return this._name;
  227. }
  228. get group() {
  229. return this._group ? this._group : null;
  230. }
  231. get inputs() {
  232. return this._inputs;
  233. }
  234. get outputs() {
  235. return this._outputs;
  236. }
  237. get attributes() {
  238. return this._attributes;
  239. }
  240. };
  241. sklearn.Attribute = class {
  242. constructor(metadata, name, value) {
  243. this._name = name;
  244. this._value = value;
  245. if (metadata) {
  246. if (metadata.option === 'optional' && this._value == null) {
  247. this._visible = false;
  248. }
  249. else if (metadata.visible === false) {
  250. this._visible = false;
  251. }
  252. else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
  253. if (sklearn.Attribute._isEquivalent(metadata.default, this._value)) {
  254. this._visible = false;
  255. }
  256. }
  257. }
  258. if (value) {
  259. if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj.__class__ && obj.__class__.__module__ === value[0].__class__.__module__ && obj.__class__.__name__ === value[0].__class__.__name__)) {
  260. this._type = value[0].__class__.__module__ + '.' + value[0].__class__.__name__ + '[]';
  261. }
  262. else if (value.__class__) {
  263. this._type = value.__class__.__module__ + '.' + value.__class__.__name__;
  264. }
  265. }
  266. }
  267. get name() {
  268. return this._name;
  269. }
  270. get value() {
  271. return this._value;
  272. }
  273. get type() {
  274. return this._type;
  275. }
  276. get visible() {
  277. return this._visible == false ? false : true;
  278. }
  279. static _isEquivalent(a, b) {
  280. if (a === b) {
  281. return a !== 0 || 1 / a === 1 / b;
  282. }
  283. if (a == null || b == null) {
  284. return false;
  285. }
  286. if (a !== a) {
  287. return b !== b;
  288. }
  289. const type = typeof a;
  290. if (type !== 'function' && type !== 'object' && typeof b != 'object') {
  291. return false;
  292. }
  293. const className = toString.call(a);
  294. if (className !== toString.call(b)) {
  295. return false;
  296. }
  297. switch (className) {
  298. case '[object RegExp]':
  299. case '[object String]':
  300. return '' + a === '' + b;
  301. case '[object Number]': {
  302. if (+a !== +a) {
  303. return +b !== +b;
  304. }
  305. return +a === 0 ? 1 / +a === 1 / b : +a === +b;
  306. }
  307. case '[object Date]':
  308. case '[object Boolean]': {
  309. return +a === +b;
  310. }
  311. case '[object Array]': {
  312. let length = a.length;
  313. if (length !== b.length) {
  314. return false;
  315. }
  316. while (length--) {
  317. if (!sklearn.Attribute._isEquivalent(a[length], b[length])) {
  318. return false;
  319. }
  320. }
  321. return true;
  322. }
  323. }
  324. const keys = Object.keys(a);
  325. let size = keys.length;
  326. if (Object.keys(b).length != size) {
  327. return false;
  328. }
  329. while (size--) {
  330. const key = keys[size];
  331. if (!(Object.prototype.hasOwnProperty.call(b, key) && sklearn.Attribute._isEquivalent(a[key], b[key]))) {
  332. return false;
  333. }
  334. }
  335. return true;
  336. }
  337. };
  338. sklearn.Tensor = class {
  339. constructor(value) {
  340. if (!sklearn.Utility.isTensor(value)) {
  341. const type = value.__class__.__module__ + '.' + value.__class__.__name__;
  342. throw new sklearn.Error("Unknown tensor type '" + type + "'.");
  343. }
  344. this._kind = 'NumPy Array';
  345. this._type = new sklearn.TensorType(value.dtype.name, new sklearn.TensorShape(value.shape));
  346. this._data = value.data;
  347. if (value.dtype.name === 'string') {
  348. this._itemsize = value.dtype.itemsize;
  349. }
  350. }
  351. get type() {
  352. return this._type;
  353. }
  354. get kind() {
  355. return this._kind;
  356. }
  357. get state() {
  358. return this._context().state || null;
  359. }
  360. get value() {
  361. const context = this._context();
  362. if (context.state) {
  363. return null;
  364. }
  365. context.limit = Number.MAX_SAFE_INTEGER;
  366. return this._decode(context, 0);
  367. }
  368. toString() {
  369. const context = this._context();
  370. if (context.state) {
  371. return '';
  372. }
  373. context.limit = 10000;
  374. const value = this._decode(context, 0);
  375. switch (this._type.dataType) {
  376. case 'int64':
  377. case 'uint64':
  378. return sklearn.Tensor._stringify(value, '', ' ');
  379. }
  380. return JSON.stringify(value, null, 4);
  381. }
  382. _context() {
  383. const context = {};
  384. context.index = 0;
  385. context.count = 0;
  386. context.state = null;
  387. if (!this._type) {
  388. context.state = 'Tensor has no data type.';
  389. return context;
  390. }
  391. if (!this._data) {
  392. context.state = 'Tensor is data is empty.';
  393. return context;
  394. }
  395. context.dataType = this._type.dataType;
  396. context.dimensions = this._type.shape.dimensions;
  397. switch (context.dataType) {
  398. case 'float32':
  399. case 'float64':
  400. case 'int32':
  401. case 'uint32':
  402. case 'int64':
  403. case 'uint64':
  404. context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  405. break;
  406. case 'string':
  407. context.data = this._data;
  408. context.itemsize = this._itemsize;
  409. context.decoder = new TextDecoder('utf-8');
  410. break;
  411. default:
  412. context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
  413. return context;
  414. }
  415. return context;
  416. }
  417. _decode(context, dimension) {
  418. const results = [];
  419. const size = context.dimensions[dimension];
  420. if (dimension == context.dimensions.length - 1) {
  421. for (let i = 0; i < size; i++) {
  422. if (context.count > context.limit) {
  423. results.push('...');
  424. return results;
  425. }
  426. switch (context.dataType) {
  427. case 'float32': {
  428. results.push(context.view.getFloat32(context.index, true));
  429. context.index += 4;
  430. context.count++;
  431. break;
  432. }
  433. case 'float64': {
  434. results.push(context.view.getFloat64(context.index, true));
  435. context.index += 8;
  436. context.count++;
  437. break;
  438. }
  439. case 'int32': {
  440. results.push(context.view.getInt32(context.index, true));
  441. context.index += 4;
  442. context.count++;
  443. break;
  444. }
  445. case 'uint32': {
  446. results.push(context.view.getUint32(context.index, true));
  447. context.index += 4;
  448. context.count++;
  449. break;
  450. }
  451. case 'int64': {
  452. results.push(context.view.getInt64(context.index, true));
  453. context.index += 8;
  454. context.count++;
  455. break;
  456. }
  457. case 'uint64': {
  458. results.push(context.view.getUint64(context.index, true));
  459. context.index += 8;
  460. context.count++;
  461. break;
  462. }
  463. case 'string': {
  464. const buffer = context.data.subarray(context.index, context.index + context.itemsize);
  465. const index = buffer.indexOf(0);
  466. const content = context.decoder.decode(index >= 0 ? buffer.subarray(0, index) : buffer);
  467. results.push(content);
  468. context.index += context.itemsize;
  469. context.count++;
  470. break;
  471. }
  472. }
  473. }
  474. }
  475. else {
  476. for (let j = 0; j < size; j++) {
  477. if (context.count > context.limit) {
  478. results.push('...');
  479. return results;
  480. }
  481. results.push(this._decode(context, dimension + 1));
  482. }
  483. }
  484. return results;
  485. }
  486. static _stringify(value, indentation, indent) {
  487. if (Array.isArray(value)) {
  488. const result = [];
  489. result.push('[');
  490. const items = value.map((item) => sklearn.Tensor._stringify(item, indentation + indent, indent));
  491. if (items.length > 0) {
  492. result.push(items.join(',\n'));
  493. }
  494. result.push(']');
  495. return result.join('\n');
  496. }
  497. return indentation + value.toString();
  498. }
  499. };
  500. sklearn.TensorType = class {
  501. constructor(dataType, shape) {
  502. this._dataType = dataType;
  503. this._shape = shape;
  504. }
  505. get dataType() {
  506. return this._dataType;
  507. }
  508. get shape() {
  509. return this._shape;
  510. }
  511. toString() {
  512. return this.dataType + this._shape.toString();
  513. }
  514. };
  515. sklearn.TensorShape = class {
  516. constructor(dimensions) {
  517. this._dimensions = dimensions;
  518. }
  519. get dimensions() {
  520. return this._dimensions;
  521. }
  522. toString() {
  523. return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : '';
  524. }
  525. };
  526. sklearn.Metadata = class {
  527. static open(context) {
  528. if (sklearn.Metadata._metadata) {
  529. return Promise.resolve(sklearn.Metadata._metadata);
  530. }
  531. return context.request('sklearn-metadata.json', 'utf-8', null).then((data) => {
  532. sklearn.Metadata._metadata = new sklearn.Metadata(data);
  533. return sklearn.Metadata._metadata;
  534. }).catch(() => {
  535. sklearn.Metadata._metadata = new sklearn.Metadata(null);
  536. return sklearn.Metadata._metadata;
  537. });
  538. }
  539. constructor(data) {
  540. this._types = new Map();
  541. this._attributes = new Map();
  542. if (data) {
  543. const metadata = JSON.parse(data);
  544. this._types = new Map(metadata.map((item) => [ item.name, item ]));
  545. }
  546. }
  547. type(name) {
  548. return this._types.get(name);
  549. }
  550. attribute(type, name) {
  551. const key = type + ':' + name;
  552. if (!this._attributes.has(key)) {
  553. const schema = this.type(type);
  554. if (schema && schema.attributes && schema.attributes.length > 0) {
  555. for (const attribute of schema.attributes) {
  556. this._attributes.set(type + ':' + attribute.name, attribute);
  557. }
  558. }
  559. if (!this._attributes.has(key)) {
  560. this._attributes.set(key, null);
  561. }
  562. }
  563. return this._attributes.get(key);
  564. }
  565. };
  566. sklearn.Utility = class {
  567. static isTensor(obj) {
  568. return obj && obj.__class__ && obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray';
  569. }
  570. static findWeights(obj) {
  571. const keys = [ '', 'blobs' ];
  572. for (const key of keys) {
  573. const dict = key === '' ? obj : obj[key];
  574. if (dict) {
  575. const weights = new Map();
  576. if (dict instanceof Map) {
  577. for (const pair of dict) {
  578. if (!sklearn.Utility.isTensor(pair[1])) {
  579. return null;
  580. }
  581. weights.set(pair[0], pair[1]);
  582. }
  583. return weights;
  584. }
  585. else if (!Array.isArray(dict)) {
  586. for (const key in dict) {
  587. const value = dict[key];
  588. if (key != 'weight_order' && key != 'lr') {
  589. if (!key || !sklearn.Utility.isTensor(value)) {
  590. return null;
  591. }
  592. weights.set(key, value);
  593. }
  594. }
  595. return weights;
  596. }
  597. }
  598. }
  599. for (const key of keys) {
  600. const list = key === '' ? obj : obj[key];
  601. if (list && Array.isArray(list)) {
  602. const weights = new Map();
  603. for (let i = 0; i < list.length; i++) {
  604. const value = list[i];
  605. if (!sklearn.Utility.isTensor(value, 'numpy.ndarray')) {
  606. return null;
  607. }
  608. weights.set(i.toString(), value);
  609. }
  610. return weights;
  611. }
  612. }
  613. return null;
  614. }
  615. };
  616. sklearn.Error = class extends Error {
  617. constructor(message) {
  618. super(message);
  619. this.name = 'Error loading scikit-learn model.';
  620. }
  621. };
  622. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  623. module.exports.ModelFactory = sklearn.ModelFactory;
  624. }