lasagne.js 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. // Experimental
  2. const lasagne = {};
  3. lasagne.ModelFactory = class {
  4. async match(context) {
  5. const obj = await context.peek('pkl');
  6. if (obj && obj.__class__ && obj.__class__.__module__ === 'nolearn.lasagne.base' && obj.__class__.__name__ === 'NeuralNet') {
  7. return context.set('lasagne', obj);
  8. }
  9. return null;
  10. }
  11. async open(context) {
  12. const metadata = await context.metadata('lasagne-metadata.json');
  13. return new lasagne.Model(metadata, context.value);
  14. }
  15. };
  16. lasagne.Model = class {
  17. constructor(metadata, model) {
  18. this.format = 'Lasagne';
  19. this.modules = [new lasagne.Graph(metadata, model)];
  20. }
  21. };
  22. lasagne.Graph = class {
  23. constructor(metadata, model) {
  24. this.nodes = [];
  25. this.inputs = [];
  26. this.outputs = [];
  27. const values = new Map();
  28. values.map = (name, type, tensor) => {
  29. if (!values.has(name)) {
  30. values.set(name, new lasagne.Value(name, type, tensor));
  31. } else if (tensor) {
  32. throw new lasagne.Error(`Duplicate value '${name}'.`);
  33. } else if (type && !type.equals(values.get(name).type)) {
  34. throw new lasagne.Error(`Duplicate value '${name}'.`);
  35. }
  36. return values.get(name);
  37. };
  38. for (const [name] of model.layers) {
  39. const layer = model.layers_[name];
  40. if (layer.input_layer && layer.input_layer.name) {
  41. const input_layer = layer.input_layer;
  42. const dataType = input_layer.input_var && input_layer.input_var.type ? input_layer.input_var.type.dtype : '?';
  43. const shape = layer.input_shape ? new lasagne.TensorShape(layer.input_shape) : null;
  44. const type = shape ? new lasagne.TensorType(dataType, shape) : null;
  45. values.map(input_layer.name, type);
  46. }
  47. }
  48. for (const [name] of model.layers) {
  49. const layer = model.layers_[name];
  50. if (layer && layer.__class__ && layer.__class__.__module__ === 'lasagne.layers.input' && layer.__class__.__name__ === 'InputLayer') {
  51. const shape = new lasagne.TensorShape(layer.shape);
  52. const type = new lasagne.TensorType(layer.input_var.type.dtype, shape);
  53. const argument = new lasagne.Argument(layer.name, [values.map(layer.name, type)]);
  54. this.inputs.push(argument);
  55. continue;
  56. }
  57. this.nodes.push(new lasagne.Node(metadata, layer, values));
  58. }
  59. if (model._output_layer) {
  60. const output_layer = model._output_layer;
  61. this.outputs.push(new lasagne.Argument(output_layer.name, [values.map(output_layer.name)]));
  62. }
  63. }
  64. };
  65. lasagne.Argument = class {
  66. constructor(name, value, type = null) {
  67. this.name = name;
  68. this.value = value;
  69. this.type = type;
  70. }
  71. };
  72. lasagne.Value = class {
  73. constructor(name, type, initializer) {
  74. if (typeof name !== 'string') {
  75. throw new lasagne.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  76. }
  77. this.name = name;
  78. this.type = !type && initializer ? initializer.type : type;
  79. this.initializer = initializer;
  80. }
  81. };
  82. lasagne.Node = class {
  83. constructor(metadata, layer, values) {
  84. this.name = layer.name || '';
  85. const type = layer.__class__ ? `${layer.__class__.__module__}.${layer.__class__.__name__}` : '';
  86. this.type = metadata.type(type) || { name: type };
  87. this.inputs = [];
  88. this.outputs = [];
  89. this.attributes = [];
  90. const params = new Map();
  91. for (const [key, value] of Object.entries(layer)) {
  92. if (key === 'name' || key === 'params' || key === 'input_layer' || key === 'input_shape') {
  93. continue;
  94. }
  95. if (value && value.__class__ && value.__class__.__module__ === 'theano.tensor.sharedvar' && value.__class__.__name__ === 'TensorSharedVariable') {
  96. params.set(value.name, key);
  97. continue;
  98. }
  99. const type = value && value.__class__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : null;
  100. const attribute = new lasagne.Argument(key, value, type);
  101. this.attributes.push(attribute);
  102. }
  103. if (layer.input_layer && layer.input_layer.name) {
  104. const value = values.map(layer.input_layer.name);
  105. const argument = new lasagne.Argument('input', [value]);
  106. this.inputs.push(argument);
  107. }
  108. if (layer.params) {
  109. for (const [param] of layer.params) {
  110. const param_key = params.get(param.name);
  111. if (param_key) {
  112. const initializer = new lasagne.Tensor(param.container.storage[0]);
  113. const argument = new lasagne.Argument(param_key, [values.map(param.name, null, initializer)]);
  114. this.inputs.push(argument);
  115. }
  116. }
  117. }
  118. this.outputs.push(new lasagne.Argument('output', [values.map(this.name)]));
  119. }
  120. };
  121. lasagne.TensorType = class {
  122. constructor(dataType, shape) {
  123. this.dataType = dataType;
  124. this.shape = shape;
  125. }
  126. equals(obj) {
  127. return obj && this.dataType === obj.dataType && this.shape && this.shape.equals(obj.shape);
  128. }
  129. toString() {
  130. return this.dataType + this.shape.toString();
  131. }
  132. };
  133. lasagne.TensorShape = class {
  134. constructor(dimensions) {
  135. this.dimensions = dimensions;
  136. }
  137. equals(obj) {
  138. return obj && Array.isArray(obj.dimensions) && Array.isArray(this.dimensions) &&
  139. this.dimensions.length === obj.dimensions.length &&
  140. obj.dimensions.every((value, index) => this.dimensions[index] === value);
  141. }
  142. toString() {
  143. if (this.dimensions && this.dimensions.length > 0) {
  144. return `[${this.dimensions.map((dimension) => dimension ? dimension.toString() : '?').join(',')}]`;
  145. }
  146. return '';
  147. }
  148. };
  149. lasagne.Tensor = class {
  150. constructor(storage) {
  151. this.type = new lasagne.TensorType(storage.dtype.__name__, new lasagne.TensorShape(storage.shape));
  152. this.values = storage.data;
  153. }
  154. };
  155. lasagne.Error = class extends Error {
  156. constructor(message) {
  157. super(message);
  158. this.name = 'Lasagne Error';
  159. }
  160. };
  161. export const ModelFactory = lasagne.ModelFactory;