fasttext.js 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  8. import fastTextModularized from './fasttext_wasm.js';
  9. const fastTextModule = fastTextModularized();
  10. let postRunFunc = null;
  11. const addOnPostRun = function(func) {
  12. postRunFunc = func;
  13. };
  14. fastTextModule.addOnPostRun(() => {
  15. if (postRunFunc) {
  16. postRunFunc();
  17. }
  18. });
  19. const thisModule = this;
  20. const trainFileInWasmFs = 'train.txt';
  21. const testFileInWasmFs = 'test.txt';
  22. const modelFileInWasmFs = 'model.bin';
  23. const getFloat32ArrayFromHeap = (len) => {
  24. const dataBytes = len * Float32Array.BYTES_PER_ELEMENT;
  25. const dataPtr = fastTextModule._malloc(dataBytes);
  26. const dataHeap = new Uint8Array(fastTextModule.HEAPU8.buffer,
  27. dataPtr,
  28. dataBytes);
  29. return {
  30. 'ptr':dataHeap.byteOffset,
  31. 'size':len,
  32. 'buffer':dataHeap.buffer
  33. };
  34. };
  35. const heapToFloat32 = (r) => new Float32Array(r.buffer, r.ptr, r.size);
  36. class FastText {
  37. constructor() {
  38. this.f = new fastTextModule.FastText();
  39. }
  40. /**
  41. * loadModel
  42. *
  43. * Loads the model file from the specified url, and returns the
  44. * corresponding `FastTextModel` object.
  45. *
  46. * @param {string} url
  47. * the url of the model file.
  48. *
  49. * @return {Promise} promise object that resolves to a `FastTextModel`
  50. *
  51. */
  52. loadModel(url) {
  53. const fetchFunc = (thisModule && thisModule.fetch) || fetch;
  54. const fastTextNative = this.f;
  55. return new Promise(function(resolve, reject) {
  56. fetchFunc(url).then(response => {
  57. return response.arrayBuffer();
  58. }).then(bytes => {
  59. const byteArray = new Uint8Array(bytes);
  60. const FS = fastTextModule.FS;
  61. FS.writeFile(modelFileInWasmFs, byteArray);
  62. }).then(() => {
  63. fastTextNative.loadModel(modelFileInWasmFs);
  64. resolve(new FastTextModel(fastTextNative));
  65. }).catch(error => {
  66. reject(error);
  67. });
  68. });
  69. }
  70. _train(url, modelName, kwargs = {}, callback = null) {
  71. const fetchFunc = (thisModule && thisModule.fetch) || fetch;
  72. const fastTextNative = this.f;
  73. return new Promise(function(resolve, reject) {
  74. fetchFunc(url).then(response => {
  75. return response.arrayBuffer();
  76. }).then(bytes => {
  77. const byteArray = new Uint8Array(bytes);
  78. const FS = fastTextModule.FS;
  79. FS.writeFile(trainFileInWasmFs, byteArray);
  80. }).then(() => {
  81. const argsList = ['lr', 'lrUpdateRate', 'dim', 'ws', 'epoch',
  82. 'minCount', 'minCountLabel', 'neg', 'wordNgrams', 'loss',
  83. 'model', 'bucket', 'minn', 'maxn', 't', 'label', 'verbose',
  84. 'pretrainedVectors', 'saveOutput', 'seed', 'qout', 'retrain',
  85. 'qnorm', 'cutoff', 'dsub', 'qnorm', 'autotuneValidationFile',
  86. 'autotuneMetric', 'autotunePredictions', 'autotuneDuration',
  87. 'autotuneModelSize'];
  88. const args = new fastTextModule.Args();
  89. argsList.forEach(k => {
  90. if (k in kwargs) {
  91. args[k] = kwargs[k];
  92. }
  93. });
  94. args.model = fastTextModule.ModelName[modelName];
  95. args.loss = ('loss' in kwargs) ?
  96. fastTextModule.LossName[kwargs['loss']] : 'hs';
  97. args.thread = 1;
  98. args.input = trainFileInWasmFs;
  99. fastTextNative.train(args, callback);
  100. resolve(new FastTextModel(fastTextNative));
  101. }).catch(error => {
  102. reject(error);
  103. });
  104. });
  105. }
  106. /**
  107. * trainSupervised
  108. *
  109. * Downloads the input file from the specified url, trains a supervised
  110. * model and returns a `FastTextModel` object.
  111. *
  112. * @param {string} url
  113. * the url of the input file.
  114. * The input file must must contain at least one label per line. For an
  115. * example consult the example datasets which are part of the fastText
  116. * repository such as the dataset pulled by classification-example.sh.
  117. *
  118. * @param {dict} kwargs
  119. * train parameters.
  120. * For example {'lr': 0.5, 'epoch': 5}
  121. *
  122. * @param {function} callback
  123. * train callback function
  124. * `callback` function is called regularly from the train loop:
  125. * `callback(progress, loss, wordsPerSec, learningRate, eta)`
  126. *
  127. * @return {Promise} promise object that resolves to a `FastTextModel`
  128. *
  129. */
  130. trainSupervised(url, kwargs = {}, callback) {
  131. const self = this;
  132. return new Promise(function(resolve, reject) {
  133. self._train(url, 'supervised', kwargs, callback).then(model => {
  134. resolve(model);
  135. }).catch(error => {
  136. reject(error);
  137. });
  138. });
  139. }
  140. /**
  141. * trainUnsupervised
  142. *
  143. * Downloads the input file from the specified url, trains an unsupervised
  144. * model and returns a `FastTextModel` object.
  145. *
  146. * @param {string} url
  147. * the url of the input file.
  148. * The input file must not contain any labels or use the specified label
  149. * prefixunless it is ok for those words to be ignored. For an example
  150. * consult the dataset pulled by the example script word-vector-example.sh
  151. * which is part of the fastText repository.
  152. *
  153. * @param {string} modelName
  154. * Model to be used for unsupervised learning. `cbow` or `skipgram`.
  155. *
  156. * @param {dict} kwargs
  157. * train parameters.
  158. * For example {'lr': 0.5, 'epoch': 5}
  159. *
  160. * @param {function} callback
  161. * train callback function
  162. * `callback` function is called regularly from the train loop:
  163. * `callback(progress, loss, wordsPerSec, learningRate, eta)`
  164. *
  165. * @return {Promise} promise object that resolves to a `FastTextModel`
  166. *
  167. */
  168. trainUnsupervised(url, modelName, kwargs = {}, callback) {
  169. const self = this;
  170. return new Promise(function(resolve, reject) {
  171. self._train(url, modelName, kwargs, callback).then(model => {
  172. resolve(model);
  173. }).catch(error => {
  174. reject(error);
  175. });
  176. });
  177. }
  178. }
  179. class FastTextModel {
  180. /**
  181. * `FastTextModel` represents a trained model.
  182. *
  183. * @constructor
  184. *
  185. * @param {object} fastTextNative
  186. * webassembly object that makes the bridge between js and C++
  187. */
  188. constructor(fastTextNative) {
  189. this.f = fastTextNative;
  190. }
  191. /**
  192. * isQuant
  193. *
  194. * @return {bool} true if the model is quantized
  195. *
  196. */
  197. isQuant() {
  198. return this.f.isQuant;
  199. }
  200. /**
  201. * getDimension
  202. *
  203. * @return {int} the dimension (size) of a lookup vector (hidden layer)
  204. *
  205. */
  206. getDimension() {
  207. return this.f.args.dim;
  208. }
  209. /**
  210. * getWordVector
  211. *
  212. * @param {string} word
  213. *
  214. * @return {Float32Array} the vector representation of `word`.
  215. *
  216. */
  217. getWordVector(word) {
  218. const b = getFloat32ArrayFromHeap(this.getDimension());
  219. this.f.getWordVector(b, word);
  220. return heapToFloat32(b);
  221. }
  222. /**
  223. * getSentenceVector
  224. *
  225. * @param {string} text
  226. *
  227. * @return {Float32Array} the vector representation of `text`.
  228. *
  229. */
  230. getSentenceVector(text) {
  231. if (text.indexOf('\n') != -1) {
  232. "sentence vector processes one line at a time (remove '\\n')";
  233. }
  234. text += '\n';
  235. const b = getFloat32ArrayFromHeap(this.getDimension());
  236. this.f.getSentenceVector(b, text);
  237. return heapToFloat32(b);
  238. }
  239. /**
  240. * getNearestNeighbors
  241. *
  242. * returns the nearest `k` neighbors of `word`.
  243. *
  244. * @param {string} word
  245. * @param {int} k
  246. *
  247. * @return {Array.<Pair.<number, string>>}
  248. * words and their corresponding cosine similarities.
  249. *
  250. */
  251. getNearestNeighbors(word, k = 10) {
  252. return this.f.getNN(word, k);
  253. }
  254. /**
  255. * getAnalogies
  256. *
  257. * returns the nearest `k` neighbors of the operation
  258. * `wordA - wordB + wordC`.
  259. *
  260. * @param {string} wordA
  261. * @param {string} wordB
  262. * @param {string} wordC
  263. * @param {int} k
  264. *
  265. * @return {Array.<Pair.<number, string>>}
  266. * words and their corresponding cosine similarities
  267. *
  268. */
  269. getAnalogies(wordA, wordB, wordC, k) {
  270. return this.f.getAnalogies(k, wordA, wordB, wordC);
  271. }
  272. /**
  273. * getWordId
  274. *
  275. * Given a word, get the word id within the dictionary.
  276. * Returns -1 if word is not in the dictionary.
  277. *
  278. * @return {int} word id
  279. *
  280. */
  281. getWordId(word) {
  282. return this.f.getWordId(word);
  283. }
  284. /**
  285. * getSubwordId
  286. *
  287. * Given a subword, return the index (within input matrix) it hashes to.
  288. *
  289. * @return {int} subword id
  290. *
  291. */
  292. getSubwordId(subword) {
  293. return this.f.getSubwordId(subword);
  294. }
  295. /**
  296. * getSubwords
  297. *
  298. * returns the subwords and their indicies.
  299. *
  300. * @param {string} word
  301. *
  302. * @return {Pair.<Array.<string>, Array.<int>>}
  303. * words and their corresponding indicies
  304. *
  305. */
  306. getSubwords(word) {
  307. return this.f.getSubwords(word);
  308. }
  309. /**
  310. * getInputVector
  311. *
  312. * Given an index, get the corresponding vector of the Input Matrix.
  313. *
  314. * @param {int} ind
  315. *
  316. * @return {Float32Array} the vector of the `ind`'th index
  317. *
  318. */
  319. getInputVector(ind) {
  320. const b = getFloat32ArrayFromHeap(this.getDimension());
  321. this.f.getInputVector(b, ind);
  322. return heapToFloat32(b);
  323. }
  324. /**
  325. * predict
  326. *
  327. * Given a string, get a list of labels and a list of corresponding
  328. * probabilities. k controls the number of returned labels.
  329. *
  330. * @param {string} text
  331. * @param {int} k, the number of predictions to be returned
  332. * @param {number} probability threshold
  333. *
  334. * @return {Array.<Pair.<number, string>>}
  335. * labels and their probabilities
  336. *
  337. */
  338. predict(text, k = 1, threshold = 0.0) {
  339. return this.f.predict(text, k, threshold);
  340. }
  341. /**
  342. * getInputMatrix
  343. *
  344. * Get a reference to the full input matrix of a Model. This only
  345. * works if the model is not quantized.
  346. *
  347. * @return {DenseMatrix}
  348. * densematrix with functions: `rows`, `cols`, `at(i,j)`
  349. *
  350. * example:
  351. * let inputMatrix = model.getInputMatrix();
  352. * let value = inputMatrix.at(1, 2);
  353. */
  354. getInputMatrix() {
  355. if (this.isQuant()) {
  356. throw new Error("Can't get quantized Matrix");
  357. }
  358. return this.f.getInputMatrix();
  359. }
  360. /**
  361. * getOutputMatrix
  362. *
  363. * Get a reference to the full input matrix of a Model. This only
  364. * works if the model is not quantized.
  365. *
  366. * @return {DenseMatrix}
  367. * densematrix with functions: `rows`, `cols`, `at(i,j)`
  368. *
  369. * example:
  370. * let outputMatrix = model.getOutputMatrix();
  371. * let value = outputMatrix.at(1, 2);
  372. */
  373. getOutputMatrix() {
  374. if (this.isQuant()) {
  375. throw new Error("Can't get quantized Matrix");
  376. }
  377. return this.f.getOutputMatrix();
  378. }
  379. /**
  380. * getWords
  381. *
  382. * Get the entire list of words of the dictionary including the frequency
  383. * of the individual words. This does not include any subwords. For that
  384. * please consult the function get_subwords.
  385. *
  386. * @return {Pair.<Array.<string>, Array.<int>>}
  387. * words and their corresponding frequencies
  388. *
  389. */
  390. getWords() {
  391. return this.f.getWords();
  392. }
  393. /**
  394. * getLabels
  395. *
  396. * Get the entire list of labels of the dictionary including the frequency
  397. * of the individual labels.
  398. *
  399. * @return {Pair.<Array.<string>, Array.<int>>}
  400. * labels and their corresponding frequencies
  401. *
  402. */
  403. getLabels() {
  404. return this.f.getLabels();
  405. }
  406. /**
  407. * getLine
  408. *
  409. * Split a line of text into words and labels. Labels must start with
  410. * the prefix used to create the model (__label__ by default).
  411. *
  412. * @param {string} text
  413. *
  414. * @return {Pair.<Array.<string>, Array.<string>>}
  415. * words and labels
  416. *
  417. */
  418. getLine(text) {
  419. return this.f.getLine(text);
  420. }
  421. /**
  422. * saveModel
  423. *
  424. * Saves the model file in web assembly in-memory FS and returns a blob
  425. *
  426. * @return {Blob} blob data of the file saved in web assembly FS
  427. *
  428. */
  429. saveModel() {
  430. this.f.saveModel(modelFileInWasmFs);
  431. const content = fastTextModule.FS.readFile(modelFileInWasmFs,
  432. { encoding: 'binary' });
  433. return new Blob(
  434. [new Uint8Array(content, content.byteOffset, content.length)],
  435. { type: ' application/octet-stream' }
  436. );
  437. }
  438. /**
  439. * test
  440. *
  441. * Downloads the test file from the specified url, evaluates the supervised
  442. * model with it.
  443. *
  444. * @param {string} url
  445. * @param {int} k, the number of predictions to be returned
  446. * @param {number} probability threshold
  447. *
  448. * @return {Promise} promise object that resolves to a `Meter` object
  449. *
  450. * example:
  451. * model.test("/absolute/url/to/test.txt", 1, 0.0).then((meter) => {
  452. * console.log(meter.precision);
  453. * console.log(meter.recall);
  454. * console.log(meter.f1Score);
  455. * console.log(meter.nexamples());
  456. * });
  457. *
  458. */
  459. test(url, k, threshold) {
  460. const fetchFunc = (thisModule && thisModule.fetch) || fetch;
  461. const fastTextNative = this.f;
  462. return new Promise(function(resolve, reject) {
  463. fetchFunc(url).then(response => {
  464. return response.arrayBuffer();
  465. }).then(bytes => {
  466. const byteArray = new Uint8Array(bytes);
  467. const FS = fastTextModule.FS;
  468. FS.writeFile(testFileInWasmFs, byteArray);
  469. }).then(() => {
  470. const meter = fastTextNative.test(testFileInWasmFs, k, threshold);
  471. resolve(meter);
  472. }).catch(error => {
  473. reject(error);
  474. });
  475. });
  476. }
  477. }
  478. export {FastText, addOnPostRun};