2
0

flax.js 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. // Experimental
  2. import * as python from './python.js';
  const flax = {};

  flax.ModelFactory = class {

    /**
     * Claims streams whose first byte is a MessagePack map header.
     *
     * Accepted headers are fixmap (0x80-0x8f), map16 (0xde) and map32 (0xdf).
     * Other MessagePack leading bytes (arrays, strings, numbers, nil, ...) are
     * rejected so this factory does not claim unrelated binary files.
     * @param {object} context - loader context exposing `stream` and `set()`.
     * @returns {Promise<*>} the result of `context.set('flax.msgpack.map')` on
     *   a match, otherwise `null`.
     */
    async match(context) {
      const stream = context.stream;
      if (stream.length > 4) {
        const buffer = stream.peek(1);
        const code = buffer[0];
        // Masking with 0xF0 isolates the fixmap range only. Masking with 0x80
        // would accept every byte >= 0x80.
        const isFixMap = (code & 0xF0) === 0x80;
        if (isFixMap || code === 0xDE || code === 0xDF) {
          return context.set('flax.msgpack.map');
        }
      }
      return null;
    }

    /**
     * Decodes the whole stream with `msgpack.unpackb` through the python.js
     * execution shim and wraps the decoded payload in a `flax.Model`.
     *
     * MessagePack extension code 1 is decoded into a numpy ndarray. Any other
     * extension code throws `flax.Error`.
     * @param {object} context - loader context exposing `stream`.
     * @returns {Promise<flax.Model>}
     * @throws {flax.Error} on an unsupported MessagePack extension code.
     */
    async open(context) {
      const stream = context.stream;
      const packed = stream.peek();
      const execution = new python.Execution();
      // https://github.com/google/flax/blob/main/flax/serialization.py
      const ext_hook = (code, data) => {
        switch (code) {
          case 1: { // _MsgpackExtType.ndarray
            // The extension payload is itself msgpack. Its three fields are
            // passed to numpy.ndarray as (shape, dtype, buffer).
            const tuple = execution.invoke('msgpack.unpackb', [data]);
            const dtype = execution.invoke('numpy.dtype', [tuple[1]]);
            // NOTE(review): forces little-endian interpretation of the raw
            // buffer; assumes the writer was little-endian - confirm.
            dtype.byteorder = '<';
            return execution.invoke('numpy.ndarray', [tuple[0], dtype, tuple[2]]);
          }
          default: {
            throw new flax.Error(`Unsupported MessagePack extension '${code}'.`);
          }
        }
      };
      const obj = execution.invoke('msgpack.unpackb', [packed, ext_hook]);
      return new flax.Model(obj);
    }
  };
  flax.Model = class {

    /**
     * Top-level model wrapper: tags the format and exposes the decoded
     * payload as a single graph.
     * @param {*} obj - decoded msgpack payload.
     */
    constructor(obj) {
      this.format = 'Flax';
      const graph = new flax.Graph(obj);
      this.graphs = [graph];
    }
  };
  flax.Graph = class {

    /**
     * Builds one node per "layer" of the decoded payload.
     *
     * Nested plain objects are walked recursively. Their key path, joined with
     * '.', names the layer. Tensors, arrays and primitive values found along
     * the way are stored as entries of the layer that owns them. A top-level
     * array is stored whole under the key `value` of the root layer ('').
     * @param {object|Array} obj - decoded msgpack payload.
     */
    constructor(obj) {
      this.inputs = [];
      this.outputs = [];
      const layers = new Map();
      const getLayer = (path) => {
        const key = path.join('.');
        let entry = layers.get(key);
        if (entry === undefined) {
          entry = {};
          layers.set(key, entry);
        }
        return entry;
      };
      // Only non-tensor, non-array objects are descended into. Everything
      // else is a leaf of the current layer.
      const isBranch = (value) =>
        !flax.Utility.isTensor(value) &&
        !Array.isArray(value) &&
        Object(value) === value;
      const walk = (path, node) => {
        for (const [key, value] of Object.entries(node)) {
          if (isBranch(value)) {
            walk([...path, key], value);
          } else {
            getLayer(path)[key] = value;
          }
        }
      };
      if (Array.isArray(obj)) {
        getLayer([]).value = obj;
      } else {
        walk([], obj);
      }
      this.nodes = Array.from(layers, ([name, layer]) => new flax.Node(name, layer));
    }
  };
  flax.Argument = class {

    /**
     * Named slot of a node: either an input (list of values) or an attribute.
     * @param {string} name
     * @param {*} value
     */
    constructor(name, value) {
      Object.assign(this, { name, value });
    }
  };
  flax.Value = class {

    /**
     * A value flowing into a node, optionally backed by a constant tensor.
     * @param {string} name - identifier; must be a string.
     * @param {flax.Tensor} [initializer] - constant backing the value; its
     *   `type` becomes this value's type.
     * @throws {flax.Error} when `name` is not a string.
     */
    constructor(name, initializer) {
      if (typeof name !== 'string') {
        throw new flax.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
      this.name = name;
      if (initializer) {
        this.type = initializer.type;
        this.initializer = initializer;
      } else {
        this.type = null;
        this.initializer = null;
      }
    }
  };
  flax.Node = class {

    /**
     * One layer of the state tree, rendered as a generic 'Module' node.
     *
     * Each tensor entry becomes an input argument wrapping a single anonymous
     * value. Every other entry (arrays, primitives) becomes an attribute.
     * @param {string} name - dotted layer path.
     * @param {object} layer - entries collected for this layer.
     */
    constructor(name, layer) {
      this.name = name;
      this.type = { name: 'Module' };
      this.attributes = [];
      this.inputs = [];
      this.outputs = [];
      for (const [key, value] of Object.entries(layer)) {
        if (flax.Utility.isTensor(value)) {
          const values = [new flax.Value('', new flax.Tensor(value))];
          this.inputs.push(new flax.Argument(key, values));
          continue;
        }
        this.attributes.push(new flax.Argument(key, value));
      }
    }
  };
  flax.TensorType = class {

    /**
     * Element type plus shape of a tensor.
     * @param {string} dataType - element type name; falsy falls back to '?'.
     * @param {flax.TensorShape} shape
     */
    constructor(dataType, shape) {
      this.dataType = dataType ? dataType : '?';
      this.shape = shape;
    }

    /** @returns {string} e.g. 'float32[2,3]'. */
    toString() {
      const suffix = this.shape.toString();
      return this.dataType + suffix;
    }
  };
  flax.TensorShape = class {

    /** @param {Array} dimensions - per-axis sizes. */
    constructor(dimensions) {
      this.dimensions = dimensions;
    }

    /** @returns {string} '[d0,d1,...]', or '' for a missing/empty shape. */
    toString() {
      const dims = this.dimensions;
      if (!Array.isArray(dims) || dims.length === 0) {
        return '';
      }
      return `[${dims.join(',')}]`;
    }
  };
  flax.Tensor = class {

    /**
     * Wraps an ndarray object produced by the python.js numpy shim.
     *
     * 'string' and 'object' element types get the '|' encoding. Other types
     * take the dtype's byte order.
     * @param {object} array - object exposing `dtype` (`__name__`,
     *   `byteorder`, `itemsize`), `shape` and `tobytes()`.
     */
    constructor(array) {
      const dtype = array.dtype;
      const typeName = dtype.__name__;
      this.type = new flax.TensorType(typeName, new flax.TensorShape(array.shape));
      const listEncoded = ['string', 'object'].includes(this.type.dataType);
      this.encoding = listEncoded ? '|' : dtype.byteorder;
      this._data = array.tobytes();
      this._itemsize = dtype.itemsize;
    }

    /**
     * Tensor payload. For 'string' tensors the raw bytes are decoded on first
     * access into an array of UTF-8 strings (one per fixed-width item, cut at
     * the first NUL byte) and cached. Other types return the raw bytes as-is.
     */
    get values() {
      const isPackedString = this.type.dataType === 'string' && this._data instanceof Uint8Array;
      if (isPackedString) {
        const bytes = this._data;
        const decoder = new TextDecoder('utf-8');
        const count = this.type.shape.dimensions.reduce((a, b) => a * b, 1);
        const strings = new Array(count);
        for (let i = 0; i < count; i++) {
          const start = i * this._itemsize;
          const item = bytes.subarray(start, start + this._itemsize);
          const terminator = item.indexOf(0);
          strings[i] = decoder.decode(terminator < 0 ? item : item.subarray(0, terminator));
        }
        this._data = strings;
      }
      return this._data;
    }
  };
  flax.Utility = class {

    /**
     * Tests whether `obj` is a numpy ndarray as modelled by python.js, by
     * checking its `__class__` module and name. Non-matches return a falsy
     * value that is not necessarily `false`, e.g. `obj` itself when `obj`
     * is falsy.
     * @param {*} obj
     */
    static isTensor(obj) {
      const type = obj && obj.__class__;
      return type && type.__module__ === 'numpy' && type.__name__ === 'ndarray';
    }
  };
  flax.Error = class extends Error {

    /**
     * Loader error. `name` is replaced with a fixed, human-readable title.
     * @param {string} message
     */
    constructor(message) {
      super(message);
      const title = 'Error loading Flax model.';
      this.name = title;
    }
  };
  // Only the factory is exported; the other flax.* classes stay module-internal.
  export const { ModelFactory } = flax;