// flax.js

// Experimental
  2. import * as python from './python.js';
  3. const flax = {};
  // Factory that recognizes MessagePack-encoded Flax checkpoints and loads
  // them into a flax.Model.
  flax.ModelFactory = class {

      // Inspects the first byte of the context stream.
      // Returns the result of context.set('flax.msgpack.map') when the byte
      // passes the check below, otherwise null. A context without a stream
      // (or with a stream of 4 bytes or fewer) yields null instead of throwing.
      async match(context) {
          const stream = context.stream;
          const readable = stream && stream.length > 4;
          if (!readable) {
              return null;
          }
          const head = stream.peek(1)[0];
          // 0xDE and 0xDF are the MessagePack map 16 / map 32 markers.
          // NOTE(review): the last test is true for every byte with the top
          // bit set, not only the fixmap range 0x80-0x8F; left unchanged
          // because flax.Graph also handles a top-level array - confirm intent.
          if (head === 0xDE || head === 0xDF || ((head & 0x80) === 0x80)) {
              return context.set('flax.msgpack.map');
          }
          return null;
      }

      // Reads the whole stream, unpacks it with the msgpack module provided by
      // python.Execution, and wraps the unpacked object in a flax.Model.
      // Throws flax.Error for any MessagePack extension code other than 1.
      async open(context) {
          const stream = context.stream;
          const packed = stream.peek();
          const execution = new python.Execution();
          const msgpack = execution.__import__('msgpack');
          const numpy = execution.__import__('numpy');
          // https://github.com/google/flax/blob/main/flax/serialization.py
          const ext_hook = (code, data) => {
              if (code !== 1) {
                  throw new flax.Error(`Unsupported MessagePack extension '${code}'.`);
              }
              // _MsgpackExtType.ndarray
              // The payload unpacks to a tuple: element 0 and element 2 are
              // passed to numpy.ndarray as first and third argument, element 1
              // is the dtype description.
              const tuple = msgpack.unpackb(data);
              const dtype = new numpy.dtype(tuple[1]);
              dtype.byteorder = '<';
              return new numpy.ndarray(tuple[0], dtype, tuple[2]);
          };
          const obj = msgpack.unpackb(packed, ext_hook);
          return new flax.Model(obj);
      }
  };
  // Root model object: a format label plus a single graph built from the
  // unpacked checkpoint object.
  flax.Model = class {

      // obj is forwarded unchanged to flax.Graph.
      constructor(obj) {
          const graph = new flax.Graph(obj);
          this.format = 'Flax';
          this.modules = [graph];
      }
  };
  // Groups the entries of the unpacked object into nodes keyed by their
  // dot-joined path.
  flax.Graph = class {

      // obj is either an array (stored whole as 'value' on the root group) or
      // an object that is walked recursively.
      constructor(obj) {
          this.inputs = [];
          this.outputs = [];
          const groups = new Map();
          // Returns the group for a path, creating it on first use.
          const groupFor = (path) => {
              const key = path.join('.');
              let group = groups.get(key);
              if (group === undefined) {
                  group = {};
                  groups.set(key, group);
              }
              return group;
          };
          // Only non-tensor, non-array objects are descended into; every
          // other value is stored on the group of the current path.
          const walk = (path, container) => {
              for (const [key, value] of Object.entries(container)) {
                  const nested = !flax.Utility.isTensor(value) &&
                      !Array.isArray(value) &&
                      Object(value) === value;
                  if (nested) {
                      walk([...path, key], value);
                  } else {
                      groupFor(path)[key] = value;
                  }
              }
          };
          if (Array.isArray(obj)) {
              groupFor([]).value = obj;
          } else {
              walk([], obj);
          }
          this.nodes = Array.from(groups, ([name, value]) => new flax.Node(name, value));
      }
  };
  // Simple name/value pair used for node inputs and attributes.
  flax.Argument = class {

      // Stores both constructor arguments as-is, name first.
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  // A named value, optionally backed by an initializer.
  flax.Value = class {

      // name must be a string, otherwise flax.Error is thrown.
      // type is taken from the initializer when one is given, else null.
      constructor(name, initializer = null) {
          const valid = typeof name === 'string';
          if (!valid) {
              throw new flax.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          if (initializer) {
              this.type = initializer.type;
          } else {
              this.type = null;
          }
          this.initializer = initializer;
      }
  };
  // One node per layer group; every node has the type name 'Module'.
  flax.Node = class {

      // layer is a plain object of entries: tensor entries become inputs,
      // all other entries (arrays included) become attributes.
      constructor(name, layer) {
          this.name = name;
          this.type = { name: 'Module' };
          this.attributes = [];
          this.inputs = [];
          this.outputs = [];
          Object.entries(layer).forEach(([key, value]) => {
              if (flax.Utility.isTensor(value)) {
                  const initializer = new flax.Tensor(value);
                  // The wrapped value carries an empty name.
                  this.inputs.push(new flax.Argument(key, [new flax.Value('', initializer)]));
                  return;
              }
              this.attributes.push(new flax.Argument(key, value));
          });
      }
  };
  // Pairs a data type name with a shape object.
  flax.TensorType = class {

      // A falsy dataType is replaced by '?'.
      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = shape;
      }

      // Data type followed directly by the shape's string form.
      toString() {
          const prefix = this.dataType;
          const suffix = this.shape.toString();
          return prefix + suffix;
      }
  };
  // Holds a list of dimensions.
  flax.TensorShape = class {

      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      // '[d0,d1,...]' for a non-empty array of dimensions, '' otherwise.
      toString() {
          const dims = this.dimensions;
          if (!Array.isArray(dims) || dims.length === 0) {
              return '';
          }
          return `[${dims.join(',')}]`;
      }
  };
  // Wraps an ndarray-like object: exposes its type, encoding and data.
  flax.Tensor = class {

      // Reads dtype.__name__, shape, dtype.byteorder, dtype.itemsize and
      // tobytes() from the given array.
      constructor(array) {
          const shape = new flax.TensorShape(array.shape);
          this.type = new flax.TensorType(array.dtype.__name__, shape);
          const kind = this.type.dataType;
          // 'string' and 'object' tensors use '|'; everything else uses the
          // dtype's own byte order.
          if (kind === 'string' || kind === 'object') {
              this.encoding = '|';
          } else {
              this.encoding = array.dtype.byteorder;
          }
          this._data = array.tobytes();
          this._itemsize = array.dtype.itemsize;
      }

      // For 'string' tensors still held as a Uint8Array, decodes the bytes
      // once into an array of strings, which replaces the stored data.
      // All other cases return the stored data unchanged.
      get values() {
          const pending = this.type.dataType === 'string' && this._data instanceof Uint8Array;
          if (pending) {
              const bytes = this._data;
              const utf8 = new TextDecoder('utf-8');
              const count = this.type.shape.dimensions.reduce((a, b) => a * b, 1);
              this._data = new Array(count);
              for (let i = 0, start = 0; i < count; i++, start += this._itemsize) {
                  // Each item is a fixed-size slot; text ends at the first
                  // zero byte if there is one.
                  const slot = bytes.subarray(start, start + this._itemsize);
                  const end = slot.indexOf(0);
                  this._data[i] = utf8.decode(end >= 0 ? slot.subarray(0, end) : slot);
              }
          }
          return this._data;
      }
  };
  // Static helpers shared by the graph and node builders.
  flax.Utility = class {

      // Returns true when obj has a __class__ whose __module__ is 'numpy' and
      // whose __name__ is 'ndarray', false otherwise.
      // Always returns a boolean: falsy inputs (null, undefined, 0, '') give
      // false rather than being passed through as the result.
      static isTensor(obj) {
          if (!obj || !obj.__class__) {
              return false;
          }
          const type = obj.__class__;
          return type.__module__ === 'numpy' && type.__name__ === 'ndarray';
      }
  };
  // Error type thrown by this module; its name is the fixed text
  // 'Error loading Flax model.'.
  flax.Error = class extends Error {

      // message is passed straight to the Error constructor.
      constructor(message) {
          super(message);
          this.name = 'Error loading Flax model.';
      }
  };
  179. export const ModelFactory = flax.ModelFactory;