| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- /**
- * Copyright (c) 2016-present, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
- import fastTextModularized from './fasttext_wasm.js';
// Names of the files this binding writes inside the WebAssembly file system.
const modelFileInWasmFs = 'model.bin';
const testFileInWasmFs = 'test.txt';
const trainFileInWasmFs = 'train.txt';

// Module-scope `this`; checked for a `fetch` property before the global
// `fetch` is used.
const thisModule = this;

// Object returned by the factory exported from `fasttext_wasm.js`.
const fastTextModule = fastTextModularized();

// Function most recently passed to `addOnPostRun`, or null when none was.
let postRunFunc = null;

/**
 * addOnPostRun
 *
 * Remembers `func` so that it is called when the module's post-run hook
 * fires. Only one function is kept: a later call replaces the earlier one.
 *
 * @param {function} func
 *     function to call, without arguments, from the post-run hook.
 */
const addOnPostRun = (func) => {
  postRunFunc = func;
};

// Forward the module's post-run notification to the registered function.
fastTextModule.addOnPostRun(() => {
  if (!postRunFunc) {
    return;
  }
  postRunFunc();
});
/**
 * getFloat32ArrayFromHeap
 *
 * Allocates room for `len` 32-bit floats with the module's `_malloc` and
 * describes the allocation.
 *
 * NOTE(review): nothing in this function releases the allocation; confirm
 * whether callers are expected to free `ptr`.
 *
 * @param {int} len
 *     number of float elements to reserve.
 *
 * @return {object} `{ptr, size, buffer}` where `ptr` is the byte offset of
 *     the allocation, `size` is `len`, and `buffer` is the ArrayBuffer
 *     behind `HEAPU8` at the time of the call.
 */
const getFloat32ArrayFromHeap = (len) => {
  const byteCount = Float32Array.BYTES_PER_ELEMENT * len;
  // `_malloc` runs before `HEAPU8.buffer` is read; keep this order.
  // NOTE(review): presumably matters if allocating can replace the heap
  // buffer - confirm against the build flags.
  const address = fastTextModule._malloc(byteCount);
  const { byteOffset, buffer } = new Uint8Array(
    fastTextModule.HEAPU8.buffer,
    address,
    byteCount
  );
  return { 'ptr': byteOffset, 'size': len, 'buffer': buffer };
};
/**
 * heapToFloat32
 *
 * Builds a Float32Array over the region described by `r`. The result is a
 * view on `r.buffer`, not a copy.
 *
 * @param {object} r
 *     descriptor with `buffer` (ArrayBuffer), `ptr` (byte offset) and
 *     `size` (number of float elements).
 *
 * @return {Float32Array} `size` floats starting at byte `ptr` of `buffer`.
 */
const heapToFloat32 = (r) => {
  const { buffer, ptr, size } = r;
  return new Float32Array(buffer, ptr, size);
};
/**
 * `FastText` is the entry point of the binding: it loads an existing model
 * or trains a new one, and hands back a `FastTextModel`.
 */
class FastText {
  constructor() {
    this.f = new fastTextModule.FastText();
  }

  /**
   * _fetchToWasmFs
   *
   * Downloads `url` and writes its bytes to `fileName` in the WebAssembly
   * file system. Uses `thisModule.fetch` when available, otherwise the
   * global `fetch`.
   *
   * @param {string} url
   *     the url to download.
   * @param {string} fileName
   *     destination path inside the WebAssembly file system.
   *
   * @return {Promise} promise object that resolves (with no value) once the
   *     file has been written
   *
   */
  _fetchToWasmFs(url, fileName) {
    const fetchFunc = (thisModule && thisModule.fetch) || fetch;
    return fetchFunc(url)
      .then(response => response.arrayBuffer())
      .then(bytes => {
        fastTextModule.FS.writeFile(fileName, new Uint8Array(bytes));
      });
  }

  /**
   * loadModel
   *
   * Loads the model file from the specified url, and returns the
   * corresponding `FastTextModel` object.
   *
   * @param {string} url
   *     the url of the model file.
   *
   * @return {Promise} promise object that resolves to a `FastTextModel`
   *
   */
  loadModel(url) {
    const fastTextNative = this.f;
    return this._fetchToWasmFs(url, modelFileInWasmFs).then(() => {
      fastTextNative.loadModel(modelFileInWasmFs);
      return new FastTextModel(fastTextNative);
    });
  }

  /**
   * _train
   *
   * Shared implementation of `trainSupervised` and `trainUnsupervised`:
   * downloads the input file, builds the native `Args` object from
   * `kwargs` and runs the native training.
   *
   * @param {string} url
   *     the url of the input file.
   * @param {string} modelName
   *     key of `fastTextModule.ModelName` selecting the model.
   * @param {dict} kwargs
   *     train parameters. `loss`, when present, is a key of
   *     `fastTextModule.LossName`; when absent, `'hs'` is used.
   * @param {function} callback
   *     train callback function, passed to the native `train`.
   *
   * @return {Promise} promise object that resolves to a `FastTextModel`.
   *     It rejects with an `Error` when `modelName` or the loss name is
   *     not a known key.
   *
   */
  _train(url, modelName, kwargs = {}, callback = null) {
    const fastTextNative = this.f;
    return this._fetchToWasmFs(url, trainFileInWasmFs).then(() => {
      // `loss` and `model` are not copied verbatim: they are converted to
      // their enum values below.
      const argsList = ['lr', 'lrUpdateRate', 'dim', 'ws', 'epoch',
        'minCount', 'minCountLabel', 'neg', 'wordNgrams',
        'bucket', 'minn', 'maxn', 't', 'label', 'verbose',
        'pretrainedVectors', 'saveOutput', 'seed', 'qout', 'retrain',
        'qnorm', 'cutoff', 'dsub', 'autotuneValidationFile',
        'autotuneMetric', 'autotunePredictions', 'autotuneDuration',
        'autotuneModelSize'];
      const args = new fastTextModule.Args();
      argsList.forEach(k => {
        if (k in kwargs) {
          args[k] = kwargs[k];
        }
      });

      const model = fastTextModule.ModelName[modelName];
      if (model === undefined) {
        throw new Error(`Unknown model name: ${modelName}`);
      }
      // The default goes through `LossName` as well, so `args.loss` always
      // receives an enum value rather than a raw string.
      const lossName = ('loss' in kwargs) ? kwargs['loss'] : 'hs';
      const loss = fastTextModule.LossName[lossName];
      if (loss === undefined) {
        throw new Error(`Unknown loss name: ${lossName}`);
      }

      args.model = model;
      args.loss = loss;
      args.thread = 1;
      args.input = trainFileInWasmFs;

      fastTextNative.train(args, callback);
      return new FastTextModel(fastTextNative);
    });
  }

  /**
   * trainSupervised
   *
   * Downloads the input file from the specified url, trains a supervised
   * model and returns a `FastTextModel` object.
   *
   * @param {string} url
   *     the url of the input file.
   *     The input file must contain at least one label per line. For an
   *     example consult the example datasets which are part of the fastText
   *     repository such as the dataset pulled by classification-example.sh.
   *
   * @param {dict} kwargs
   *     train parameters.
   *     For example {'lr': 0.5, 'epoch': 5}
   *
   * @param {function} callback
   *     train callback function
   *     `callback` function is called regularly from the train loop:
   *     `callback(progress, loss, wordsPerSec, learningRate, eta)`
   *
   * @return {Promise} promise object that resolves to a `FastTextModel`
   *
   */
  trainSupervised(url, kwargs = {}, callback) {
    return this._train(url, 'supervised', kwargs, callback);
  }

  /**
   * trainUnsupervised
   *
   * Downloads the input file from the specified url, trains an unsupervised
   * model and returns a `FastTextModel` object.
   *
   * @param {string} url
   *     the url of the input file.
   *     The input file must not contain any labels or use the specified label
   *     prefix unless it is ok for those words to be ignored. For an example
   *     consult the dataset pulled by the example script
   *     word-vector-example.sh which is part of the fastText repository.
   *
   * @param {string} modelName
   *     Model to be used for unsupervised learning. `cbow` or `skipgram`.
   *
   * @param {dict} kwargs
   *     train parameters.
   *     For example {'lr': 0.5, 'epoch': 5}
   *
   * @param {function} callback
   *     train callback function
   *     `callback` function is called regularly from the train loop:
   *     `callback(progress, loss, wordsPerSec, learningRate, eta)`
   *
   * @return {Promise} promise object that resolves to a `FastTextModel`
   *
   */
  trainUnsupervised(url, modelName, kwargs = {}, callback) {
    return this._train(url, modelName, kwargs, callback);
  }
}
class FastTextModel {
  /**
   * `FastTextModel` represents a trained model.
   *
   * @constructor
   *
   * @param {object} fastTextNative
   *     webassembly object that makes the bridge between js and C++
   */
  constructor(fastTextNative) {
    this.f = fastTextNative;
  }

  /**
   * isQuant
   *
   * @return {bool} true if the model is quantized
   *     (value of the native object's `isQuant` property)
   *
   */
  isQuant() {
    return this.f.isQuant;
  }

  /**
   * getDimension
   *
   * @return {int} the dimension (size) of a lookup vector (hidden layer)
   *
   */
  getDimension() {
    return this.f.args.dim;
  }

  /**
   * getWordVector
   *
   * NOTE(review): the result is built over memory obtained through
   * `getFloat32ArrayFromHeap`, which this method does not release.
   *
   * @param {string} word
   *
   * @return {Float32Array} the vector representation of `word`.
   *
   */
  getWordVector(word) {
    const b = getFloat32ArrayFromHeap(this.getDimension());
    this.f.getWordVector(b, word);

    return heapToFloat32(b);
  }

  /**
   * getSentenceVector
   *
   * @param {string} text
   *     a single line of text; it must not contain '\n'.
   *
   * @return {Float32Array} the vector representation of `text`.
   *
   * @throws {Error} if `text` contains a newline character.
   *
   */
  getSentenceVector(text) {
    if (text.indexOf('\n') !== -1) {
      throw new Error(
        "sentence vector processes one line at a time (remove '\\n')");
    }
    // The native call receives the line terminated by a newline.
    const line = text + '\n';
    const b = getFloat32ArrayFromHeap(this.getDimension());
    this.f.getSentenceVector(b, line);

    return heapToFloat32(b);
  }

  /**
   * getNearestNeighbors
   *
   * returns the nearest `k` neighbors of `word`.
   *
   * @param {string} word
   * @param {int} k
   *
   * @return {Array.<Pair.<number, string>>}
   *     words and their corresponding cosine similarities.
   *
   */
  getNearestNeighbors(word, k = 10) {
    return this.f.getNN(word, k);
  }

  /**
   * getAnalogies
   *
   * returns the nearest `k` neighbors of the operation
   * `wordA - wordB + wordC`.
   *
   * @param {string} wordA
   * @param {string} wordB
   * @param {string} wordC
   * @param {int} k, defaults to 10 like `getNearestNeighbors`
   *
   * @return {Array.<Pair.<number, string>>}
   *     words and their corresponding cosine similarities
   *
   */
  getAnalogies(wordA, wordB, wordC, k = 10) {
    return this.f.getAnalogies(k, wordA, wordB, wordC);
  }

  /**
   * getWordId
   *
   * Given a word, get the word id within the dictionary.
   * Returns -1 if word is not in the dictionary.
   *
   * @param {string} word
   *
   * @return {int} word id
   *
   */
  getWordId(word) {
    return this.f.getWordId(word);
  }

  /**
   * getSubwordId
   *
   * Given a subword, return the index (within input matrix) it hashes to.
   *
   * @param {string} subword
   *
   * @return {int} subword id
   *
   */
  getSubwordId(subword) {
    return this.f.getSubwordId(subword);
  }

  /**
   * getSubwords
   *
   * returns the subwords and their indices.
   *
   * @param {string} word
   *
   * @return {Pair.<Array.<string>, Array.<int>>}
   *     words and their corresponding indices
   *
   */
  getSubwords(word) {
    return this.f.getSubwords(word);
  }

  /**
   * getInputVector
   *
   * Given an index, get the corresponding vector of the Input Matrix.
   *
   * NOTE(review): the result is built over memory obtained through
   * `getFloat32ArrayFromHeap`, which this method does not release.
   *
   * @param {int} ind
   *
   * @return {Float32Array} the vector of the `ind`'th index
   *
   */
  getInputVector(ind) {
    const b = getFloat32ArrayFromHeap(this.getDimension());
    this.f.getInputVector(b, ind);

    return heapToFloat32(b);
  }

  /**
   * predict
   *
   * Given a string, get a list of labels and a list of corresponding
   * probabilities. k controls the number of returned labels.
   *
   * @param {string} text
   * @param {int} k, the number of predictions to be returned
   * @param {number} threshold, the probability threshold
   *
   * @return {Array.<Pair.<number, string>>}
   *     labels and their probabilities
   *
   */
  predict(text, k = 1, threshold = 0.0) {
    return this.f.predict(text, k, threshold);
  }

  /**
   * getInputMatrix
   *
   * Get a reference to the full input matrix of a Model. This only
   * works if the model is not quantized.
   *
   * @return {DenseMatrix}
   *     densematrix with functions: `rows`, `cols`, `at(i,j)`
   *
   * @throws {Error} if the model is quantized.
   *
   * example:
   *     let inputMatrix = model.getInputMatrix();
   *     let value = inputMatrix.at(1, 2);
   */
  getInputMatrix() {
    if (this.isQuant()) {
      throw new Error("Can't get quantized Matrix");
    }
    return this.f.getInputMatrix();
  }

  /**
   * getOutputMatrix
   *
   * Get a reference to the full output matrix of a Model. This only
   * works if the model is not quantized.
   *
   * @return {DenseMatrix}
   *     densematrix with functions: `rows`, `cols`, `at(i,j)`
   *
   * @throws {Error} if the model is quantized.
   *
   * example:
   *     let outputMatrix = model.getOutputMatrix();
   *     let value = outputMatrix.at(1, 2);
   */
  getOutputMatrix() {
    if (this.isQuant()) {
      throw new Error("Can't get quantized Matrix");
    }
    return this.f.getOutputMatrix();
  }

  /**
   * getWords
   *
   * Get the entire list of words of the dictionary including the frequency
   * of the individual words. This does not include any subwords. For that
   * please consult the function `getSubwords`.
   *
   * @return {Pair.<Array.<string>, Array.<int>>}
   *     words and their corresponding frequencies
   *
   */
  getWords() {
    return this.f.getWords();
  }

  /**
   * getLabels
   *
   * Get the entire list of labels of the dictionary including the frequency
   * of the individual labels.
   *
   * @return {Pair.<Array.<string>, Array.<int>>}
   *     labels and their corresponding frequencies
   *
   */
  getLabels() {
    return this.f.getLabels();
  }

  /**
   * getLine
   *
   * Split a line of text into words and labels. Labels must start with
   * the prefix used to create the model (__label__ by default).
   *
   * @param {string} text
   *
   * @return {Pair.<Array.<string>, Array.<string>>}
   *     words and labels
   *
   */
  getLine(text) {
    return this.f.getLine(text);
  }

  /**
   * saveModel
   *
   * Saves the model file in web assembly in-memory FS and returns a blob
   *
   * @return {Blob} blob data of the file saved in web assembly FS, with
   *     type `application/octet-stream`
   *
   */
  saveModel() {
    this.f.saveModel(modelFileInWasmFs);
    const content = fastTextModule.FS.readFile(modelFileInWasmFs,
      { encoding: 'binary' });

    // The Blob constructor copies the bytes it is given, so no extra typed
    // array copy is needed.
    return new Blob([content], { type: 'application/octet-stream' });
  }

  /**
   * test
   *
   * Downloads the test file from the specified url, evaluates the supervised
   * model with it.
   *
   * @param {string} url
   * @param {int} k, the number of predictions to be returned (default 1)
   * @param {number} threshold, the probability threshold (default 0.0)
   *
   * @return {Promise} promise object that resolves to a `Meter` object
   *
   * example:
   * model.test("/absolute/url/to/test.txt", 1, 0.0).then((meter) => {
   *     console.log(meter.precision);
   *     console.log(meter.recall);
   *     console.log(meter.f1Score);
   *     console.log(meter.nexamples());
   * });
   *
   */
  test(url, k = 1, threshold = 0.0) {
    const fetchFunc = (thisModule && thisModule.fetch) || fetch;
    const fastTextNative = this.f;

    return fetchFunc(url)
      .then(response => response.arrayBuffer())
      .then(bytes => {
        fastTextModule.FS.writeFile(testFileInWasmFs, new Uint8Array(bytes));
        return fastTextNative.test(testFileInWasmFs, k, threshold);
      });
  }
}
- export {FastText, addOnPostRun};
|