// numpy.js (9.4 KB)
  1. // Experimental
  2. import * as python from './python.js';
  // Namespace object; the classes defined in this file are attached to it as properties.
  const numpy = {};
  numpy.ModelFactory = class {

      /**
       * Decides whether the content behind `context` is a NumPy array or archive.
       *
       * @param {object} context - provides `stream`, `peek(type)` and `set(type, value)`.
       * @returns {Promise<*>} the result of `context.set('npy')` when the stream starts
       *     with the bytes 0x93 'N' 'U' 'M' 'P' 'Y', the result of
       *     `context.set('npz', entries)` when `context.peek('npz')` yields a non-empty
       *     collection, and `null` otherwise.
       */
      async match(context) {
          const magic = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59];
          const { stream } = context;
          // Only peek when the stream is long enough to hold the whole signature.
          if (stream && stream.length >= magic.length) {
              const head = stream.peek(magic.length);
              if (head.every((byte, position) => byte === magic[position])) {
                  return context.set('npy');
              }
          }
          const archive = await context.peek('npz');
          return archive && archive.size > 0 ? context.set('npz', archive) : null;
      }

      /**
       * Builds a `numpy.Model` for the type previously selected by `match`.
       *
       * @param {object} context - `context.type` is 'npy' or 'npz'; 'npy' reads
       *     `context.stream`, 'npz' reads the `[key, array]` pairs in `context.value`.
       * @returns {Promise<numpy.Model>} model with a single module.
       * @throws {numpy.Error} for an unknown `context.type`, or when loading an 'npy'
       *     stream reported an unresolved type name.
       */
      async open(context) {
          // 'npy': load one array through the python execution and wrap it in one layer.
          const readArray = () => {
              const unresolved = new Set();
              const execution = new python.Execution();
              execution.on('resolve', (sender, name) => unresolved.add(name));
              const stream = context.stream;
              const io = execution.__import__('io');
              const np = execution.__import__('numpy');
              const array = np.load(new io.BytesIO(stream));
              if (unresolved.size > 0) {
                  const [name] = unresolved;
                  throw new numpy.Error(`Unknown type name '${name}'.`);
              }
              const parameters = [{ name: 'value', tensor: { name: '', array } }];
              return { layers: [{ type: 'numpy.ndarray', parameters }] };
          };
          // 'npz': group the entries into layers by the path in front of the last separator.
          const readArchive = () => {
              const entries = Array.from(context.value);
              // Archives whose keys all end in '.weight.npy' use '.' as the path separator.
              const dotted = entries.every(([key]) => key.endsWith('.weight.npy'));
              const separator = dotted ? '.' : '/';
              const groups = new Map();
              for (const [key, array] of entries) {
                  const name = key.replace(/\.npy$/, '');
                  const cut = name.lastIndexOf(separator);
                  const groupName = cut === -1 ? '' : name.slice(0, cut);
                  const parameterName = cut === -1 ? name : name.slice(cut + separator.length);
                  let group = groups.get(groupName);
                  if (!group) {
                      group = { name: groupName, parameters: [] };
                      groups.set(groupName, group);
                  }
                  group.parameters.push({ name: parameterName, tensor: { name, array } });
              }
              return { layers: [...groups.values()] };
          };
          switch (context.type) {
              case 'npy':
                  return new numpy.Model('NumPy Array', [readArray()]);
              case 'npz':
                  return new numpy.Model('NumPy Archive', [readArchive()]);
              default:
                  throw new numpy.Error(`Unsupported NumPy format '${context.type}'.`);
          }
      }
  };
  numpy.Model = class {

      /**
       * @param {string} format - display name of the file format.
       * @param {object[]} modules - module descriptions, each wrapped in a `numpy.Module`.
       */
      constructor(format, modules) {
          const wrap = (description) => new numpy.Module(description);
          this.format = format;
          this.modules = modules.map(wrap);
      }
  };
  numpy.Module = class {

      /**
       * @param {object} graph - `{ name?, layers }`; every layer becomes a `numpy.Node`.
       */
      constructor(graph) {
          const { name, layers } = graph;
          this.name = name || '';
          this.nodes = layers.map((description) => new numpy.Node(description));
          // This module never populates graph-level inputs or outputs.
          this.inputs = [];
          this.outputs = [];
      }
  };
  numpy.Argument = class {

      /**
       * Pairs an argument name with the value it carries.
       *
       * @param {string} name - argument name.
       * @param {*} value - stored as given.
       */
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  numpy.Value = class {

      /**
       * A named value, optionally backed by an initializer tensor.
       *
       * @param {string} name - value identifier.
       * @param {object|null} [initializer=null] - tensor backing this value; its `type`
       *     is copied onto the value.
       * @throws {numpy.Error} when `name` is not a string.
       */
      constructor(name, initializer = null) {
          if (typeof name !== 'string') {
              throw new numpy.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          // The parameter defaults to null, so reading `.type` unguarded would throw a
          // TypeError for a value constructed without an initializer.
          this.type = initializer === null ? null : initializer.type;
          this.initializer = initializer;
      }
  };
  numpy.Node = class {

      /**
       * @param {object} layer - `{ name?, type?, parameters }`; each parameter is
       *     `{ name, tensor: { name, array } }` and becomes one input argument.
       */
      constructor(layer) {
          const { name, type, parameters } = layer;
          this.name = name || '';
          this.type = { name: type || 'Object' };
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          for (const { name: label, tensor } of parameters) {
              // Each parameter is exposed as a single-value argument backed by its tensor.
              const value = new numpy.Value(tensor.name || '', new numpy.Tensor(tensor.array));
              this.inputs.push(new numpy.Argument(label, [value]));
          }
      }
  };
  numpy.Tensor = class {

      /**
       * Wraps an array object for display.
       *
       * @param {object} array - reads `dtype.__name__`, `dtype.byteorder`, `shape`,
       *     `strides`, `itemsize` and calls `tobytes()` or `flatten().tolist()`.
       */
      constructor(array) {
          const shape = new numpy.TensorShape(array.shape);
          this.type = new numpy.TensorType(array.dtype.__name__, shape);
          // Strides are divided by the item size, giving element counts per step.
          this.stride = array.strides.map((step) => step / array.itemsize);
          switch (this.type.dataType) {
              case 'string':
              case 'object':
              case 'void':
                  // These data types are kept as a flat list of values, not raw bytes.
                  this.values = array.flatten().tolist();
                  this.encoding = '|';
                  break;
              default:
                  this.values = array.tobytes();
                  this.encoding = array.dtype.byteorder;
                  break;
          }
      }
  };
  numpy.TensorType = class {

      /**
       * @param {string} dataType - element type name; a falsy value is stored as '?'.
       * @param {object} shape - shape object whose `toString()` is used for display.
       */
      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = shape;
      }

      /** @returns {string} the data type name followed by the shape's string form. */
      toString() {
          const { dataType, shape } = this;
          return dataType + shape.toString();
      }
  };
  numpy.TensorShape = class {

      /**
       * @param {Array} dimensions - dimension sizes, stored as given.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const sizes = this.dimensions;
          if (sizes && sizes.length > 0) {
              return `[${sizes.join(',')}]`;
          }
          return '';
      }
  };
  numpy.Utility = class {

      /**
       * Tests whether `obj` is tagged as a NumPy array.
       *
       * @param {*} obj - candidate object.
       * @returns {*} truthy when `obj.__class__` identifies `numpy.ndarray` or
       *     `numpy.core.memmap.memmap`; otherwise a falsy value (not necessarily `false`).
       */
      static isTensor(obj) {
          return obj && obj.__class__ &&
              ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
              (obj.__class__.__module__ === 'numpy.core.memmap' && obj.__class__.__name__ === 'memmap'));
      }

      /**
       * Searches `obj` for a recognizable collection of tensors.
       *
       * The object itself and its 'blobs', 'model' and 'experiment_state' members are
       * probed first as dictionaries (Map or plain object), then as lists.
       *
       * @param {*} obj - deserialized object to inspect.
       * @returns {Map<string, *>|null} weights keyed by a dotted name, or `null` when no
       *     supported layout is found (including when `obj` is null or undefined).
       */
      static weights(obj) {
          // Nothing to search; this also avoids a TypeError from `obj[key]` below.
          if (obj === null || obj === undefined) {
              return null;
          }
          const isTensor = (value) => numpy.Utility.isTensor(value);
          // Non-tensor members of a plain object that are skipped instead of rejected.
          const ignored = new Set(['weight_order', 'lr', 'model_iter', '__class__']);

          // Collects tensors from a Map or plain object, going one level into nested
          // containers whose members are all tensors. Returns null on any other member.
          const fromDictionary = (root, key) => {
              const container = key === '' ? root : root[key];
              if (!container) {
                  return null;
              }
              const weights = new Map();
              if (container instanceof Map) {
                  for (const [name, value] of container) {
                      if (isTensor(value)) {
                          weights.set(name, value);
                          continue;
                      }
                      if (value instanceof Map && Array.from(value).every(([, item]) => isTensor(item))) {
                          for (const [child, item] of value) {
                              weights.set(`${name}.${child}`, item);
                          }
                          continue;
                      }
                      if (name === '_metadata') {
                          continue;
                      }
                      return null;
                  }
                  return weights;
              }
              if (!Array.isArray(container)) {
                  for (const [name, value] of Object.entries(container)) {
                      if (isTensor(value)) {
                          weights.set(name, value);
                          continue;
                      }
                      if (ignored.has(name)) {
                          continue;
                      }
                      if (value && !Array.isArray(value) && Object.entries(value).every(([, item]) => isTensor(item))) {
                          // Record the group's class name next to its tensors when available.
                          if (value.__class__ && value.__class__.__module__ && value.__class__.__name__) {
                              weights.set(`${name}.__class__`, `${value.__class__.__module__}.${value.__class__.__name__}`);
                          }
                          for (const [child, item] of Object.entries(value)) {
                              weights.set(`${name}.${child}`, item);
                          }
                          continue;
                      }
                      return null;
                  }
                  return weights;
              }
              return null;
          };

          // Collects tensors from an array whose items are tensors or containers of
          // tensors, keyed by index. Returns null when any item is something else.
          const fromList = (root, key) => {
              let items = key === '' ? root : root[key];
              if (!Array.isArray(items)) {
                  return null;
              }
              // A null or undefined item can never be a tensor container; checking for it
              // first keeps Object.values from throwing on such an item.
              const isTensorGroup = (item) => item !== null && item !== undefined &&
                  Object.values(item).every((value) => isTensor(value));
              if (items.every(isTensorGroup)) {
                  items = items.map((item) => item instanceof Map ? item : new Map(Object.entries(item)));
              }
              const weights = new Map();
              for (let i = 0; i < items.length; i++) {
                  const item = items[i];
                  if (isTensor(item)) {
                      weights.set(i.toString(), item);
                      continue;
                  }
                  if (item instanceof Map && Array.from(item).every(([, value]) => isTensor(value))) {
                      for (const [name, value] of item) {
                          weights.set(`${i}.${name}`, value);
                      }
                      continue;
                  }
                  return null;
              }
              return weights;
          };

          const keys = ['', 'blobs', 'model', 'experiment_state'];
          for (const key of keys) {
              const weights = fromDictionary(obj, key);
              if (weights && weights.size > 0) {
                  return weights;
              }
          }
          for (const key of keys) {
              const weights = fromList(obj, key);
              if (weights) {
                  return weights;
              }
          }
          return null;
      }
  };
  numpy.Error = class extends Error {

      /**
       * Error raised by the NumPy loader.
       *
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          // Fixed label shown in place of the default error name.
          Object.assign(this, { name: 'Error loading NumPy model.' });
      }
  };
  // Expose the factory class as a named export of this module.
  export const ModelFactory = numpy.ModelFactory;