flax.js 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
// Experimental
// Flax checkpoint reader: `flax.ModelFactory.open()` decodes a MessagePack
// stream and exposes the nested state dictionary as a graph of nodes.
// NOTE(review): `var x = x || ...` presumably lets this file run both as a
// plain browser script (reusing globals already defined by other scripts) and
// under a CommonJS loader, where `python` falls back to `require('./python')`.
var flax = flax || {};
var python = python || require('./python');
  4. flax.ModelFactory = class {
  5. match(context) {
  6. const stream = context.stream;
  7. if (stream.length > 4) {
  8. const code = stream.peek(1)[0];
  9. if (code === 0xDE || code === 0xDF || ((code & 0x80) === 0x80)) {
  10. return 'msgpack.map';
  11. }
  12. }
  13. return '';
  14. }
  15. open(context) {
  16. return context.require('./msgpack').then((msgpack) => {
  17. const stream = context.stream;
  18. const buffer = stream.peek();
  19. const execution = new python.Execution(null);
  20. const reader = msgpack.BinaryReader.open(buffer, (code, data) => {
  21. switch (code) {
  22. case 1: { // _MsgpackExtType.ndarray
  23. const reader = msgpack.BinaryReader.open(data);
  24. const tuple = reader.read();
  25. const dtype = execution.invoke('numpy.dtype', [ tuple[1] ]);
  26. dtype.byteorder = '<';
  27. return execution.invoke('numpy.ndarray', [ tuple[0], dtype, tuple[2] ]);
  28. }
  29. default: {
  30. throw new flax.Error("Unsupported MessagePack extension '" + code + "'.");
  31. }
  32. }
  33. });
  34. const obj = reader.read();
  35. return new flax.Model(obj);
  36. });
  37. }
  38. };
  39. flax.Model = class {
  40. constructor(obj) {
  41. this._graphs = [ new flax.Graph(obj) ];
  42. }
  43. get format() {
  44. return 'Flax';
  45. }
  46. get graphs() {
  47. return this._graphs;
  48. }
  49. };
  50. flax.Graph = class {
  51. constructor(obj) {
  52. const layers = new Map();
  53. const flatten = (path, obj) => {
  54. if (Object.entries(obj).every((entry) => entry[1].__class__ && entry[1].__class__.__module__ === 'numpy' && entry[1].__class__.__name__ === 'ndarray')) {
  55. layers.set(path.join('.'), obj);
  56. }
  57. else {
  58. for (const pair of Object.entries(obj)) {
  59. flatten(path.concat(pair[0]), pair[1]);
  60. }
  61. }
  62. };
  63. flatten([], obj);
  64. this._nodes = Array.from(layers).map((entry) => new flax.Node(entry[0], entry[1]));
  65. }
  66. get inputs() {
  67. return [];
  68. }
  69. get outputs() {
  70. return [];
  71. }
  72. get nodes() {
  73. return this._nodes;
  74. }
  75. };
  76. flax.Parameter = class {
  77. constructor(name, args) {
  78. this._name = name;
  79. this._arguments = args;
  80. }
  81. get name() {
  82. return this._name;
  83. }
  84. get visible() {
  85. return true;
  86. }
  87. get arguments() {
  88. return this._arguments;
  89. }
  90. };
  91. flax.Argument = class {
  92. constructor(name, initializer) {
  93. if (typeof name !== 'string') {
  94. throw new flax.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  95. }
  96. this._name = name;
  97. this._initializer = initializer || null;
  98. }
  99. get name() {
  100. return this._name;
  101. }
  102. get type() {
  103. return this._initializer.type;
  104. }
  105. get initializer() {
  106. return this._initializer;
  107. }
  108. };
  109. flax.Node = class {
  110. constructor(name, weights) {
  111. this._name = name;
  112. this._type = { name: 'Module' };
  113. this._inputs = [];
  114. for (const entry of Object.entries(weights)) {
  115. const name = entry[0];
  116. const tensor = new flax.Tensor(entry[1]);
  117. const argument = new flax.Argument(this._name + '.' + name, tensor);
  118. const parameter = new flax.Parameter(name, [ argument ]);
  119. this._inputs.push(parameter);
  120. }
  121. }
  122. get type() {
  123. return this._type;
  124. }
  125. get name() {
  126. return this._name;
  127. }
  128. get inputs() {
  129. return this._inputs;
  130. }
  131. get outputs() {
  132. return [];
  133. }
  134. get attributes() {
  135. return [];
  136. }
  137. };
  138. flax.TensorType = class {
  139. constructor(dataType, shape) {
  140. this._dataType = dataType;
  141. this._shape = shape;
  142. }
  143. get dataType() {
  144. return this._dataType || '?';
  145. }
  146. get shape() {
  147. return this._shape;
  148. }
  149. toString() {
  150. return this.dataType + this._shape.toString();
  151. }
  152. };
  153. flax.TensorShape = class {
  154. constructor(dimensions) {
  155. this._dimensions = dimensions;
  156. }
  157. get dimensions() {
  158. return this._dimensions;
  159. }
  160. toString() {
  161. if (!this._dimensions || this._dimensions.length == 0) {
  162. return '';
  163. }
  164. return '[' + this._dimensions.join(',') + ']';
  165. }
  166. };
  167. flax.Tensor = class {
  168. constructor(array) {
  169. this._type = new flax.TensorType(array.dtype.__name__, new flax.TensorShape(array.shape));
  170. this._data = array.tobytes();
  171. this._byteorder = array.dtype.byteorder;
  172. this._itemsize = array.dtype.itemsize;
  173. }
  174. get type() {
  175. return this._type;
  176. }
  177. get state() {
  178. return this._context().state;
  179. }
  180. get value() {
  181. const context = this._context();
  182. if (context.state) {
  183. return null;
  184. }
  185. context.limit = Number.MAX_SAFE_INTEGER;
  186. return this._decode(context, 0);
  187. }
  188. toString() {
  189. const context = this._context();
  190. if (context.state) {
  191. return '';
  192. }
  193. context.limit = 10000;
  194. const value = this._decode(context, 0);
  195. return flax.Tensor._stringify(value, '', ' ');
  196. }
  197. _context() {
  198. const context = {};
  199. context.index = 0;
  200. context.count = 0;
  201. context.state = null;
  202. if (this._byteorder !== '<' && this._byteorder !== '>' && this._type.dataType !== 'uint8' && this._type.dataType !== 'int8') {
  203. context.state = 'Tensor byte order is not supported.';
  204. return context;
  205. }
  206. if (!this._data || this._data.length == 0) {
  207. context.state = 'Tensor data is empty.';
  208. return context;
  209. }
  210. context.itemSize = this._itemsize;
  211. context.dimensions = this._type.shape.dimensions;
  212. context.dataType = this._type.dataType;
  213. context.littleEndian = this._byteorder == '<';
  214. context.data = this._data;
  215. context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  216. return context;
  217. }
  218. _decode(context, dimension) {
  219. const littleEndian = context.littleEndian;
  220. const shape = context.dimensions.length == 0 ? [ 1 ] : context.dimensions;
  221. const results = [];
  222. const size = shape[dimension];
  223. if (dimension == shape.length - 1) {
  224. for (let i = 0; i < size; i++) {
  225. if (context.count > context.limit) {
  226. results.push('...');
  227. return results;
  228. }
  229. if (context.rawData) {
  230. switch (context.dataType) {
  231. case 'float16':
  232. results.push(context.rawData.getFloat16(context.index, littleEndian));
  233. break;
  234. case 'float32':
  235. results.push(context.rawData.getFloat32(context.index, littleEndian));
  236. break;
  237. case 'float64':
  238. results.push(context.rawData.getFloat64(context.index, littleEndian));
  239. break;
  240. case 'int8':
  241. results.push(context.rawData.getInt8(context.index, littleEndian));
  242. break;
  243. case 'int16':
  244. results.push(context.rawData.getInt16(context.index, littleEndian));
  245. break;
  246. case 'int32':
  247. results.push(context.rawData.getInt32(context.index, littleEndian));
  248. break;
  249. case 'int64':
  250. results.push(context.rawData.getInt64(context.index, littleEndian));
  251. break;
  252. case 'uint8':
  253. results.push(context.rawData.getUint8(context.index, littleEndian));
  254. break;
  255. case 'uint16':
  256. results.push(context.rawData.getUint16(context.index, littleEndian));
  257. break;
  258. case 'uint32':
  259. results.push(context.rawData.getUint32(context.index, littleEndian));
  260. break;
  261. default:
  262. throw new flax.Error("Unsupported tensor data type '" + context.dataType + "'.");
  263. }
  264. context.index += context.itemSize;
  265. context.count++;
  266. }
  267. }
  268. }
  269. else {
  270. for (let j = 0; j < size; j++) {
  271. if (context.count > context.limit) {
  272. results.push('...');
  273. return results;
  274. }
  275. results.push(this._decode(context, dimension + 1));
  276. }
  277. }
  278. if (context.dimensions.length == 0) {
  279. return results[0];
  280. }
  281. return results;
  282. }
  283. static _stringify(value, indentation, indent) {
  284. if (Array.isArray(value)) {
  285. const result = [];
  286. result.push(indentation + '[');
  287. const items = value.map((item) => flax.Tensor._stringify(item, indentation + indent, indent));
  288. if (items.length > 0) {
  289. result.push(items.join(',\n'));
  290. }
  291. result.push(indentation + ']');
  292. return result.join('\n');
  293. }
  294. if (typeof value == 'string') {
  295. return indentation + value;
  296. }
  297. if (value == Infinity) {
  298. return indentation + 'Infinity';
  299. }
  300. if (value == -Infinity) {
  301. return indentation + '-Infinity';
  302. }
  303. if (isNaN(value)) {
  304. return indentation + 'NaN';
  305. }
  306. return indentation + value.toString();
  307. }
  308. };
  309. flax.Error = class extends Error {
  310. constructor(message) {
  311. super(message);
  312. this.name = 'Error loading Flax model.';
  313. }
  314. };
// Expose the factory to CommonJS loaders. Skipped when `module` is undefined
// or `module.exports` is not an object (e.g. when loaded as a plain script).
if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = flax.ModelFactory;
}