benchmark.js 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /*
  2. * Copyright (C) 2017 Apple Inc. All rights reserved.
  3. *
  4. * Redistribution and use in source and binary forms, with or without
  5. * modification, are permitted provided that the following conditions
  6. * are met:
  7. * 1. Redistributions of source code must retain the above copyright
  8. * notice, this list of conditions and the following disclaimer.
  9. * 2. Redistributions in binary form must reproduce the above copyright
  10. * notice, this list of conditions and the following disclaimer in the
  11. * documentation and/or other materials provided with the distribution.
  12. *
  13. * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
  14. * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  15. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
  16. * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR
  17. * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
  18. * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  19. * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
  20. * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
  21. * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  22. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  23. * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  24. */
  25. "use strict";
  26. let currentTime;
  27. if (this.performance && performance.now)
  28. currentTime = function() { return performance.now() };
  29. else if (this.preciseTime)
  30. currentTime = function() { return preciseTime() * 1000; };
  31. else
  32. currentTime = function() { return +new Date(); };
  33. class MLBenchmark {
  34. constructor() { }
  35. runIteration()
  36. {
  37. let Matrix = MLMatrix;
  38. let ACTIVATION_FUNCTIONS = FeedforwardNeuralNetworksActivationFunctions;
  39. function run() {
  40. let it = (name, f) => {
  41. f();
  42. };
  43. function assert(b) {
  44. if (!b)
  45. throw new Error("Bad");
  46. }
  47. var functions = Object.keys(ACTIVATION_FUNCTIONS);
  48. it('Training the neural network with XOR operator', function () {
  49. var trainingSet = new Matrix([[0, 0], [0, 1], [1, 0], [1, 1]]);
  50. var predictions = [false, true, true, false];
  51. for (var i = 0; i < functions.length; ++i) {
  52. var options = {
  53. hiddenLayers: [4],
  54. iterations: 40,
  55. learningRate: 0.3,
  56. activation: functions[i]
  57. };
  58. var xorNN = new FeedforwardNeuralNetwork(options);
  59. xorNN.train(trainingSet, predictions);
  60. var results = xorNN.predict(trainingSet);
  61. }
  62. });
  63. it('Training the neural network with AND operator', function () {
  64. var trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
  65. var predictions = [[1, 0], [1, 0], [1, 0], [0, 1]];
  66. for (var i = 0; i < functions.length; ++i) {
  67. var options = {
  68. hiddenLayers: [3],
  69. iterations: 75,
  70. learningRate: 0.3,
  71. activation: functions[i]
  72. };
  73. var andNN = new FeedforwardNeuralNetwork(options);
  74. andNN.train(trainingSet, predictions);
  75. var results = andNN.predict(trainingSet);
  76. }
  77. });
  78. it('Export and import', function () {
  79. var trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
  80. var predictions = [0, 1, 1, 1];
  81. for (var i = 0; i < functions.length; ++i) {
  82. var options = {
  83. hiddenLayers: [4],
  84. iterations: 40,
  85. learningRate: 0.3,
  86. activation: functions[i]
  87. };
  88. var orNN = new FeedforwardNeuralNetwork(options);
  89. orNN.train(trainingSet, predictions);
  90. var model = JSON.parse(JSON.stringify(orNN));
  91. var networkNN = FeedforwardNeuralNetwork.load(model);
  92. var results = networkNN.predict(trainingSet);
  93. }
  94. });
  95. it('Multiclass clasification', function () {
  96. var trainingSet = [[0, 0], [0, 1], [1, 0], [1, 1]];
  97. var predictions = [2, 0, 1, 0];
  98. for (var i = 0; i < functions.length; ++i) {
  99. var options = {
  100. hiddenLayers: [4],
  101. iterations: 40,
  102. learningRate: 0.5,
  103. activation: functions[i]
  104. };
  105. var nn = new FeedforwardNeuralNetwork(options);
  106. nn.train(trainingSet, predictions);
  107. var result = nn.predict(trainingSet);
  108. }
  109. });
  110. it('Big case', function () {
  111. var trainingSet = [[1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [1, 3], [1, 4], [4, 1],
  112. [6, 1], [6, 2], [6, 3], [6, 4], [6, 5], [5, 5], [4, 5], [3, 5]];
  113. var predictions = [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0],
  114. [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]];
  115. for (var i = 0; i < functions.length; ++i) {
  116. var options = {
  117. hiddenLayers: [20],
  118. iterations: 60,
  119. learningRate: 0.01,
  120. activation: functions[i]
  121. };
  122. var nn = new FeedforwardNeuralNetwork(options);
  123. nn.train(trainingSet, predictions);
  124. var result = nn.predict([[5, 4]]);
  125. assert(result[0][0] < result[0][1]);
  126. }
  127. });
  128. }
  129. run();
  130. }
  131. }
  132. function runBenchmark()
  133. {
  134. const numIterations = 60;
  135. let before = currentTime();
  136. let benchmark = new MLBenchmark();
  137. for (let iteration = 0; iteration < numIterations; ++iteration)
  138. benchmark.runIteration();
  139. let after = currentTime();
  140. return after - before;
  141. }