train_supervised.html 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  <!DOCTYPE html>
  <html lang="en">
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no">
    <!-- A <title> element is required for a conforming HTML document. -->
    <title>fastText: trainSupervised example</title>
  </head>
  <body>
  <!-- This page renders nothing visible; every result below is written to the browser console. -->
  <script type="module">
  /**
   * Log the first `limit` elements of a vector-like object, one console line each.
   *
   * The object only needs `size()` and `get(i)`; that is the accessor shape used
   * for the values returned by the FastText model in this script.
   *
   * @param {{size: function(): number, get: function(number): *}} predictions
   *   Vector-like object to print.
   * @param {?number} [limit] Maximum number of elements to print.
   *   `undefined`/`null` print every element; `0` prints nothing.
   */
  const printVector = function(predictions, limit) {
    // `??` rather than `||` so that an explicit limit of 0 is honoured
    // instead of silently becoming "print everything".
    const count = Math.min(predictions.size(), limit ?? Infinity);
    for (let i = 0; i < count; i++) {
      // Fetch each element exactly once.
      console.log(predictions.get(i));
    }
  };
  /**
   * Progress callback handed to `trainSupervised`; logs each update it receives
   * as a single five-element array.
   *
   * @param {number} progress Training progress value reported by fastText.
   * @param {number} loss Current loss value.
   * @param {number} wst Throughput figure (presumably words/sec/thread — confirm).
   * @param {number} lr Current learning rate.
   * @param {number} eta Estimated time remaining (units not shown here — confirm).
   */
  const trainCallback = function(progress, loss, wst, lr, eta) {
    const snapshot = [progress, loss, wst, lr, eta];
    console.log(snapshot);
  };
  19. import {FastText, addOnPostRun} from "./fasttext.js";
  /*
   * Entry point. `addOnPostRun` (imported from ./fasttext.js) presumably fires
   * once the WebAssembly runtime is ready — confirm against fasttext.js.
   * Trains a supervised model on "cooking.train", then logs a sample
   * prediction, the input/output matrix shapes, and the first 30 words and
   * labels with their frequencies.
   */
  addOnPostRun(async () => {
    // Logs a matrix's column count, row count, and the entry at (1, 2).
    const logMatrixSummary = (matrix) => {
      console.log(matrix.cols());
      console.log(matrix.rows());
      console.log(matrix.at(1, 2));
    };

    try {
      const ft = new FastText();
      const model = await ft.trainSupervised("cooking.train", {
        lr: 1.0,
        epoch: 10,
        loss: 'hs',
        wordNgrams: 2,
        dim: 50,
        bucket: 200000,
      }, trainCallback);
      console.log('Trained.');

      // Arguments presumably are (text, k, threshold) — confirm against fasttext.js.
      printVector(model.predict("Which baking dish is best to bake a banana bread ?", 5, 0.0));

      /* getInputMatrix */
      logMatrixSummary(model.getInputMatrix());
      /* getOutputMatrix */
      logMatrixSummary(model.getOutputMatrix());

      /* getWords: index 0 holds the words, index 1 their frequencies. */
      const wordsInformation = model.getWords();
      printVector(wordsInformation[0], 30); // words
      printVector(wordsInformation[1], 30); // frequencies

      /* getLabels: index 0 holds the labels, index 1 their frequencies. */
      const labelsInformation = model.getLabels();
      printVector(labelsInformation[0], 30); // labels
      printVector(labelsInformation[1], 30); // frequencies
    } catch (err) {
      // Without this handler, a rejected training promise or an exception while
      // inspecting the model would surface only as an unhandled promise rejection.
      console.error('Training or model inspection failed:', err);
    }
  });
  52. </script>
  53. </body>
  54. </html>