// flax.js (10 KB)

  // Experimental
  // Namespace setup. `var name = name || fallback` keeps any value already bound to `flax` / `python`;
  // otherwise `flax` starts as an empty object and `python` is loaded with require('./python').
  // NOTE(review): `var` (not `const`) appears deliberate so the binding can be pre-populated when the
  // file is loaded as a plain script — confirm before changing it.
  var flax = flax || {};
  var python = python || require('./python');
  flax.ModelFactory = class {

      /**
       * Sniffs the first byte of the stream for a MessagePack map header.
       *
       * MessagePack encodes maps as fixmap (0x80-0x8f, bit pattern 1000xxxx), map16 (0xde)
       * or map32 (0xdf). This reader expects the decoded root to be a mapping, so only these
       * header bytes are accepted.
       *
       * @param {object} context - exposes `stream` with `length` and `peek(count)`.
       * @returns {string} 'msgpack.map' when the header matches, otherwise ''.
       */
      match(context) {
          const stream = context.stream;
          if (stream.length > 4) {
              const code = stream.peek(1)[0];
              // Mask with 0xF0, not 0x80: testing only the top bit would also accept fixarray,
              // fixstr, negative fixint and every other type byte >= 0x80.
              if (code === 0xDE || code === 0xDF || (code & 0xF0) === 0x80) {
                  return 'msgpack.map';
              }
          }
          return '';
      }

      /**
       * Decodes the whole stream as MessagePack and wraps the result in a flax.Model.
       *
       * @param {object} context - exposes `require(id)` returning a Promise and `stream.peek()`.
       * @returns {Promise<flax.Model>} resolves to the model built from the decoded root object.
       * @throws {flax.Error} (as a rejection) for MessagePack extension codes other than 1 and 3.
       */
      open(context) {
          return context.require('./msgpack').then((msgpack) => {
              const stream = context.stream;
              const buffer = stream.peek();
              const execution = new python.Execution(null);
              // An array extension payload is itself a MessagePack tuple (shape, dtype name, raw bytes),
              // which is rebuilt here as a numpy ndarray through the python execution shim.
              const readArray = (data) => {
                  const tuple = msgpack.BinaryReader.open(data).read();
                  const dtype = execution.invoke('numpy.dtype', [ tuple[1] ]);
                  // NOTE(review): the raw bytes are assumed to be little-endian — confirm for big-endian writers.
                  dtype.byteorder = '<';
                  return execution.invoke('numpy.ndarray', [ tuple[0], dtype, tuple[2] ]);
              };
              const reader = msgpack.BinaryReader.open(buffer, (code, data) => {
                  switch (code) {
                      case 1: // _MsgpackExtType.ndarray
                      case 3: // _MsgpackExtType.npscalar: a NumPy scalar packed as a 0-d array
                          return readArray(data);
                      default:
                          throw new flax.Error("Unsupported MessagePack extension '" + code + "'.");
                  }
              });
              const obj = reader.read();
              return new flax.Model(obj);
          });
      }
  };
  flax.Model = class {

      /**
       * Top-level model: the decoded checkpoint exposed as a single graph.
       * @param {object} obj - root value decoded from the MessagePack stream.
       */
      constructor(obj) {
          const graph = new flax.Graph(obj);
          this._graphs = [ graph ];
      }

      /** @returns {string} the format label, always 'Flax'. */
      get format() {
          return 'Flax';
      }

      /** @returns {flax.Graph[]} a one-element list holding the graph built in the constructor. */
      get graphs() {
          return this._graphs;
      }
  };
  flax.Graph = class {

      /**
       * Flattens a decoded Flax state dict into one node per module.
       *
       * Every nested dictionary that directly contains ndarrays becomes a node named by its
       * dotted key path (the root path is ''), holding only those ndarray entries in their
       * original order. Nested plain objects are walked recursively. Other values (numbers,
       * strings, arrays, binary blobs, null) are skipped, because flax.Node builds a
       * flax.Tensor for every entry it receives.
       *
       * @param {object} obj - root mapping decoded from the MessagePack stream.
       */
      constructor(obj) {
          const isTensor = (value) =>
              value !== null &&
              typeof value === 'object' &&
              Boolean(value.__class__) &&
              value.__class__.__module__ === 'numpy' &&
              value.__class__.__name__ === 'ndarray';
          const isMapping = (value) =>
              value !== null &&
              typeof value === 'object' &&
              !Array.isArray(value) &&
              !ArrayBuffer.isView(value);
          const layers = new Map();
          const flatten = (path, dict) => {
              for (const [ key, value ] of Object.entries(dict)) {
                  if (isTensor(value)) {
                      const name = path.join('.');
                      if (!layers.has(name)) {
                          // Null-prototype object so a checkpoint key such as '__proto__' stays a plain entry.
                          layers.set(name, Object.create(null));
                      }
                      layers.get(name)[key] = value;
                  } else if (isMapping(value)) {
                      flatten(path.concat(key), value);
                  }
              }
          };
          if (isMapping(obj)) {
              flatten([], obj);
          }
          this._nodes = Array.from(layers, ([ name, weights ]) => new flax.Node(name, weights));
      }

      get inputs() {
          return [];
      }

      get outputs() {
          return [];
      }

      /** @returns {flax.Node[]} one node per module, in depth-first discovery order. */
      get nodes() {
          return this._nodes;
      }
  };
  flax.Parameter = class {

      /**
       * Named input slot of a node, bound to a list of arguments.
       * @param {string} name - slot name (in this file, a weight key within its module).
       * @param {flax.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, args) {
          this._name = name;
          this._arguments = args;
      }

      /** @returns {flax.Argument[]} the arguments passed to the constructor. */
      get arguments() {
          return this._arguments;
      }

      /** @returns {boolean} always true. */
      get visible() {
          return true;
      }

      /** @returns {string} the slot name passed to the constructor. */
      get name() {
          return this._name;
      }
  };
  flax.Argument = class {

      /**
       * A named value feeding a node; in this file, Node passes a flax.Tensor as the initializer.
       * @param {string} name - identifier of the value (Node uses the dotted weight path).
       * @param {flax.Tensor} [initializer] - constant data; stored as null when omitted or falsy.
       * @throws {flax.Error} when `name` is not a string.
       */
      constructor(name, initializer) {
          if (typeof name !== 'string') {
              throw new flax.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._initializer = initializer || null;
      }

      get name() {
          return this._name;
      }

      /** @returns {?object} the initializer's `type`, or null when there is no initializer. */
      get type() {
          // `_initializer` is null when none was given; dereferencing it unguarded would throw a TypeError.
          return this._initializer ? this._initializer.type : null;
      }

      /** @returns {?flax.Tensor} the initializer, or null. */
      get initializer() {
          return this._initializer;
      }
  };
  flax.Node = class {

      /**
       * A module of the checkpoint, shown as a node whose inputs are its weights.
       * @param {string} name - node name (Graph passes the module's dotted key path).
       * @param {object} weights - mapping from weight key to an array object; each value is
       *     handed to flax.Tensor.
       */
      constructor(name, weights) {
          this._name = name;
          this._type = { name: 'Module' };
          // One parameter per weight, each wrapping a single argument named '<node>.<key>'.
          this._inputs = Object.entries(weights).map(([ key, array ]) => {
              const tensor = new flax.Tensor(array);
              const argument = new flax.Argument(name + '.' + key, tensor);
              return new flax.Parameter(key, [ argument ]);
          });
      }

      /** @returns {{name: string}} always `{ name: 'Module' }`. */
      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get inputs() {
          return this._inputs;
      }

      /** @returns {Array} always empty. */
      get outputs() {
          return [];
      }

      /** @returns {Array} always empty. */
      get attributes() {
          return [];
      }
  };
  flax.TensorType = class {

      /**
       * Element type plus shape of a tensor.
       * @param {string} dataType - dtype name such as 'float32'.
       * @param {flax.TensorShape} shape - dimensions of the tensor.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      /** @returns {string} the dtype name, or '?' when it is missing or empty. */
      get dataType() {
          return this._dataType || '?';
      }

      get shape() {
          return this._shape;
      }

      /** @returns {string} the dtype name followed by the shape's string form. */
      toString() {
          return `${this.dataType}${this._shape.toString()}`;
      }
  };
  flax.TensorShape = class {

      /**
       * Dimension list of a tensor.
       * @param {Array} dimensions - size of each axis; empty for a 0-d tensor.
       */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dimensions = this._dimensions;
          const isEmpty = !dimensions || dimensions.length === 0;
          return isEmpty ? '' : `[${dimensions.join(',')}]`;
      }
  };
  flax.Tensor = class {

      /**
       * Wraps an ndarray object decoded from the checkpoint.
       * @param {object} array - reads `array.dtype.name`, `array.dtype.byteorder`,
       *     `array.dtype.itemsize`, `array.shape` and `array.tobytes()` (expected to return a typed byte array).
       */
      constructor(array) {
          this._type = new flax.TensorType(array.dtype.name, new flax.TensorShape(array.shape));
          this._data = array.tobytes();
          this._byteorder = array.dtype.byteorder;
          this._itemsize = array.dtype.itemsize;
      }

      get type() {
          return this._type;
      }

      /** @returns {?string} why the data cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** @returns {*} every element as nested arrays (a bare value for 0-d tensors), or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} pretty-printed data capped at 10000 elements ('...' marks truncation), or '' when `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return flax.Tensor._stringify(value, '', ' ');
      }

      // Validates the tensor and builds the cursor object threaded through `_decode`.
      _context() {
          const context = { index: 0, count: 0, state: null };
          const dataType = this._type.dataType;
          // Single-byte types have no byte order, so any `byteorder` value is accepted for them.
          if (this._byteorder !== '<' && this._byteorder !== '>' && dataType !== 'uint8' && dataType !== 'int8') {
              context.state = 'Tensor byte order is not supported.';
              return context;
          }
          if (!this._data || this._data.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const dimensions = this._type.shape.dimensions;
          if (Array.isArray(dimensions)) {
              // Reject truncated buffers up front rather than letting DataView throw a RangeError mid-decode.
              const elements = dimensions.reduce((product, dimension) => product * Number(dimension), 1);
              if (elements * this._itemsize > this._data.length) {
                  context.state = 'Tensor data is too small for its shape.';
                  return context;
              }
          }
          context.itemSize = this._itemsize;
          context.dimensions = dimensions;
          context.dataType = dataType;
          context.littleEndian = this._byteorder === '<';
          context.data = this._data;
          context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
          return context;
      }

      // Decodes axis `dimension` and everything below it, stopping after exactly `context.limit` elements.
      _decode(context, dimension) {
          // A 0-d tensor is walked as a one-element vector and unwrapped at the end.
          const shape = context.dimensions.length === 0 ? [ 1 ] : context.dimensions;
          const size = shape[dimension];
          const innermost = dimension === shape.length - 1;
          const results = [];
          for (let i = 0; i < size; i++) {
              if (context.count >= context.limit) {
                  results.push('...');
                  return results;
              }
              if (innermost) {
                  results.push(flax.Tensor._read(context));
                  context.index += context.itemSize;
                  context.count++;
              } else {
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return context.dimensions.length === 0 ? results[0] : results;
      }

      // Reads the single element of `context.dataType` located at byte offset `context.index`.
      static _read(context) {
          const view = context.rawData;
          const offset = context.index;
          const littleEndian = context.littleEndian;
          // NOTE(review): getFloat16 is missing from DataView in older runtimes and getInt64 is not a
          // standard DataView method; presumably a project polyfill provides both — confirm.
          switch (context.dataType) {
              case 'float16': return view.getFloat16(offset, littleEndian);
              case 'float32': return view.getFloat32(offset, littleEndian);
              case 'float64': return view.getFloat64(offset, littleEndian);
              case 'int8': return view.getInt8(offset);
              case 'int16': return view.getInt16(offset, littleEndian);
              case 'int32': return view.getInt32(offset, littleEndian);
              case 'int64': return view.getInt64(offset, littleEndian);
              case 'uint8': return view.getUint8(offset);
              case 'uint16': return view.getUint16(offset, littleEndian);
              case 'uint32': return view.getUint32(offset, littleEndian);
              default: throw new flax.Error("Unsupported tensor data type '" + context.dataType + "'.");
          }
      }

      // Renders nested arrays one element per line, indenting each nesting level by `indent`.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const lines = [ indentation + '[' ];
              const items = value.map((item) => flax.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  lines.push(items.join(',\n'));
              }
              lines.push(indentation + ']');
              return lines.join('\n');
          }
          if (typeof value === 'string') {
              return indentation + value;
          }
          // Number#toString already yields 'NaN' / 'Infinity' / '-Infinity'. Calling toString directly
          // also works for BigInt, whereas the global isNaN() throws a TypeError on BigInt.
          return indentation + value.toString();
      }
  };
  flax.Error = class extends Error {

      /**
       * Error raised while reading a Flax checkpoint.
       * `name` is set to a descriptive title rather than to the class name.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          this.name = 'Error loading Flax model.';
      }
  };
  // CommonJS export of the factory; skipped when no `module` object with an `exports` object exists.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      Object.assign(module.exports, { ModelFactory: flax.ModelFactory });
  }