numpy.js 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. // Experimental
  2. import * as python from './python.js';
  // Namespace object: every class in this file (ModelFactory, Model, Graph, ...) is attached to it as a property.
  3. const numpy = {};
  /**
   * Factory that recognizes NumPy content ('.npy' arrays and '.npz' archives)
   * and turns it into a numpy.Model.
   */
  numpy.ModelFactory = class {

      /**
       * Decides whether the context holds NumPy data.
       *
       * The stream is checked for the six leading bytes 0x93 'N' 'U' 'M' 'P' 'Y';
       * if they are not present the context is asked for an 'npz' view instead.
       *
       * @param {object} context - Provides `stream`, `peek(type)` and `set(type, value)`.
       * @returns {Promise<*>} Whatever `context.set` returns on a match, otherwise null.
       */
      async match(context) {
          const magic = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59];
          const { stream } = context;
          if (stream && magic.length <= stream.length) {
              const header = stream.peek(magic.length);
              if (header.every((byte, position) => byte === magic[position])) {
                  return context.set('npy');
              }
          }
          const archive = await context.peek('npz');
          return archive && archive.size > 0 ? context.set('npz', archive) : null;
      }

      /**
       * Builds a numpy.Model from a context previously tagged by match().
       *
       * @param {object} context - `type` is 'npy' or 'npz'; 'npy' reads `stream`,
       *     'npz' reads `value` (an iterable of [entry name, array] pairs).
       * @returns {Promise<numpy.Model>} Model with a single graph.
       * @throws {numpy.Error} If loading the array reported an unresolved type name,
       *     or if `context.type` is neither 'npy' nor 'npz'.
       */
      async open(context) {
          const graphs = [];
          let format = '';
          const kind = context.type;
          if (kind === 'npy') {
              format = 'NumPy Array';
              // Every name announced through the 'resolve' event is remembered;
              // any such name is treated as fatal once loading has finished.
              const missing = new Set();
              const execution = new python.Execution();
              execution.on('resolve', (_, name) => missing.add(name));
              const buffer = execution.invoke('io.BytesIO', [context.stream]);
              const array = execution.invoke('numpy.load', [buffer]);
              if (missing.size > 0) {
                  const [first] = missing;
                  throw new numpy.Error(`Unknown type name '${first}'.`);
              }
              const tensor = { name: '', array };
              const parameters = [{ name: 'value', tensor }];
              graphs.push({ layers: [{ type: 'numpy.ndarray', parameters }] });
          } else if (kind === 'npz') {
              format = 'NumPy Zip';
              const entries = Array.from(context.value);
              // Archives whose entries all end in '.weight.npy' are grouped on '.',
              // every other archive is grouped on '/'.
              const dotted = entries.every(([name]) => name.endsWith('.weight.npy'));
              const separator = dotted ? '.' : '/';
              const groups = new Map();
              for (const [key, array] of entries) {
                  const name = key.replace(/\.npy$/, '');
                  // Text after the last separator is the parameter name, text before it the group.
                  const cut = name.lastIndexOf(separator);
                  const groupName = cut === -1 ? '' : name.slice(0, cut);
                  const parameterName = cut === -1 ? name : name.slice(cut + 1);
                  let group = groups.get(groupName);
                  if (!group) {
                      group = { name: groupName, parameters: [] };
                      groups.set(groupName, group);
                  }
                  group.parameters.push({ name: parameterName, tensor: { name, array } });
              }
              graphs.push({ layers: [...groups.values()] });
          } else {
              throw new numpy.Error(`Unsupported NumPy format '${context.type}'.`);
          }
          return new numpy.Model(format, graphs);
      }
  };
  /**
   * Top-level model: a format label plus the graphs built from plain descriptions.
   */
  numpy.Model = class {

      /**
       * @param {string} format - Display name of the container format.
       * @param {object[]} graphs - Graph descriptions, each wrapped in a numpy.Graph.
       */
      constructor(format, graphs) {
          const toGraph = (description) => new numpy.Graph(description);
          this.format = format;
          this.graphs = graphs.map(toGraph);
      }
  };
  /**
   * Graph made of one numpy.Node per layer description; it has no graph-level inputs or outputs.
   */
  numpy.Graph = class {

      /**
       * @param {object} graph - Description with an optional `name` and a `layers` array.
       */
      constructor(graph) {
          const { name, layers } = graph;
          this.name = name || '';
          this.nodes = layers.map((description) => new numpy.Node(description));
          this.inputs = [];
          this.outputs = [];
      }
  };
  /**
   * Named argument of a node, pairing a name with its value list.
   */
  numpy.Argument = class {

      /**
       * @param {string} name - Argument name.
       * @param {Array} value - Values bound to the argument.
       */
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  /**
   * Named value that optionally carries a tensor initializer.
   */
  numpy.Value = class {

      /**
       * @param {string} name - Value identifier; must be a string.
       * @param {object} [initializer] - Tensor backing the value. When it is null or
       *     undefined, both `type` and `initializer` are set to null.
       * @throws {numpy.Error} If `name` is not a string.
       */
      constructor(name, initializer) {
          if (typeof name !== 'string') {
              throw new numpy.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          // Read the type only when an initializer exists; previously the property
          // access ran first and threw before the `|| null` fallback could apply.
          const missing = initializer === null || initializer === undefined;
          this.type = missing ? null : initializer.type;
          this.initializer = initializer || null;
      }
  };
  /**
   * Node built from a layer description; each parameter becomes one input
   * argument holding a single tensor-backed value.
   */
  numpy.Node = class {

      /**
       * @param {object} layer - Description with optional `name` and `type`, and an
       *     iterable `parameters` of `{ name, tensor: { name, array } }` records.
       */
      constructor(layer) {
          this.name = layer.name || '';
          this.type = { name: layer.type || 'Object' };
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          for (const parameter of layer.parameters) {
              const source = parameter.tensor;
              const tensor = new numpy.Tensor(source.array);
              const value = new numpy.Value(source.name || '', tensor);
              this.inputs.push(new numpy.Argument(parameter.name, [value]));
          }
      }
  };
  /**
   * Tensor view over an ndarray-like object: type, element strides, values and encoding.
   */
  numpy.Tensor = class {

      /**
       * @param {object} array - Object exposing `dtype` (`__name__`, `byteorder`), `shape`,
       *     `strides`, `itemsize`, `flatten()`, `tolist()` and `tobytes()`.
       */
      constructor(array) {
          const typeName = array.dtype.__name__;
          const shape = new numpy.TensorShape(array.shape);
          this.type = new numpy.TensorType(typeName, shape);
          // Strides are divided by the item size, i.e. expressed in elements.
          this.stride = array.strides.map((step) => step / array.itemsize);
          const { dataType } = this.type;
          const textual = dataType === 'string' || dataType === 'object';
          // 'string', 'object' and 'void' data is kept as a flat list; everything else as raw bytes.
          if (textual || dataType === 'void') {
              this.values = array.flatten().tolist();
          } else {
              this.values = array.tobytes();
          }
          this.encoding = textual ? '|' : array.dtype.byteorder;
      }
  };
  /**
   * Tensor type: element data type name plus shape.
   */
  numpy.TensorType = class {

      /**
       * @param {string} dataType - Data type name; a falsy value is replaced by '?'.
       * @param {numpy.TensorShape} shape - Shape of the tensor.
       */
      constructor(dataType, shape) {
          this.dataType = dataType || '?';
          this.shape = shape;
      }

      /**
       * @returns {string} The data type immediately followed by the shape's string form.
       */
      toString() {
          const suffix = this.shape.toString();
          return this.dataType + suffix;
      }
  };
  /**
   * Tensor shape holding a list of dimensions.
   */
  numpy.TensorShape = class {

      /**
       * @param {Array} dimensions - Dimension sizes.
       */
      constructor(dimensions) {
          this.dimensions = dimensions;
      }

      /**
       * @returns {string} '[d0,d1,...]', or '' when the dimensions are missing or empty.
       */
      toString() {
          const dims = this.dimensions;
          if (!dims || !(dims.length > 0)) {
              return '';
          }
          return `[${dims.join(',')}]`;
      }
  };
  /**
   * Helpers for locating NumPy tensors inside deserialized object graphs.
   */
  numpy.Utility = class {

      /**
       * Tests whether `obj` is tagged as a NumPy array.
       *
       * @param {*} obj - Candidate object.
       * @returns {*} Truthy when `obj.__class__` names `numpy.ndarray` or
       *     `numpy.core.memmap.memmap`; otherwise a falsy value (not necessarily `false`).
       */
      static isTensor(obj) {
          return obj && obj.__class__ &&
              ((obj.__class__.__module__ === 'numpy' && obj.__class__.__name__ === 'ndarray') ||
              (obj.__class__.__module__ === 'numpy.core.memmap' && obj.__class__.__name__ === 'memmap'));
      }

      /**
       * Collects tensors from `obj`, looking at the object itself and at its
       * 'blobs', 'model' and 'experiment_state' members.
       *
       * Dictionary-like containers (Map or plain object) are tried first and must yield
       * at least one entry; array containers are tried second. Nested groups are flattened
       * into '<group>.<name>' keys (or '<index>.<name>' for arrays).
       *
       * @param {*} obj - Deserialized root object.
       * @returns {Map<string, *>|null} Map from weight name to tensor, or null when `obj`
       *     is null/undefined or no supported layout was found.
       */
      static weights(obj) {
          // Without a root there is nothing to search; `obj[key]` below would throw.
          if (obj === null || obj === undefined) {
              return null;
          }
          const isTensor = (value) => numpy.Utility.isTensor(value);

          // Extracts weights from a Map or plain-object container found at `key`.
          const fromDict = (root, key) => {
              const container = key === '' ? root : root[key];
              if (!container) {
                  return null;
              }
              const weights = new Map();
              if (container instanceof Map) {
                  for (const [name, value] of container) {
                      if (isTensor(value)) {
                          weights.set(name, value);
                          continue;
                      }
                      if (value instanceof Map && Array.from(value).every(([, item]) => isTensor(item))) {
                          for (const [child, item] of value) {
                              weights.set(`${name}.${child}`, item);
                          }
                          continue;
                      }
                      if (name === '_metadata') {
                          continue;
                      }
                      // Any other entry means this is not a pure weight dictionary.
                      return null;
                  }
                  return weights;
              }
              if (Array.isArray(container)) {
                  return null;
              }
              // Known non-tensor bookkeeping keys that are skipped instead of rejecting the container.
              const ignored = new Set(['weight_order', 'lr', 'model_iter', '__class__']);
              for (const [name, value] of Object.entries(container)) {
                  if (isTensor(value)) {
                      weights.set(name, value);
                      continue;
                  }
                  if (ignored.has(name)) {
                      continue;
                  }
                  if (value && !Array.isArray(value) && Object.entries(value).every(([, item]) => isTensor(item))) {
                      if (value.__class__ && value.__class__.__module__ && value.__class__.__name__) {
                          weights.set(`${name}.__class__`, `${value.__class__.__module__}.${value.__class__.__name__}`);
                      }
                      // Prefix each nested tensor with the name of its parent group. The inner
                      // loop variable used to shadow `name`, which produced 'child.child' keys.
                      for (const [child, item] of Object.entries(value)) {
                          weights.set(`${name}.${child}`, item);
                      }
                      continue;
                  }
                  return null;
              }
              return weights;
          };

          // Extracts weights from an array container found at `key`.
          const fromList = (root, key) => {
              let items = key === '' ? root : root[key];
              if (!Array.isArray(items)) {
                  return null;
              }
              // null/undefined elements are rejected here because Object.values() would throw on them.
              const onlyTensorRecords = items.every((item) =>
                  item !== null && item !== undefined && Object.values(item).every((value) => isTensor(value)));
              if (onlyTensorRecords) {
                  // Normalize plain records to Maps so the loop below handles a single shape.
                  items = items.map((item) => item instanceof Map ? item : new Map(Object.entries(item)));
              }
              const weights = new Map();
              for (let i = 0; i < items.length; i++) {
                  const item = items[i];
                  if (isTensor(item)) {
                      weights.set(i.toString(), item);
                      continue;
                  }
                  if (item instanceof Map && Array.from(item).every(([, value]) => isTensor(value))) {
                      for (const [name, value] of item) {
                          weights.set(`${i}.${name}`, value);
                      }
                      continue;
                  }
                  return null;
              }
              return weights;
          };

          const keys = ['', 'blobs', 'model', 'experiment_state'];
          for (const key of keys) {
              const weights = fromDict(obj, key);
              if (weights && weights.size > 0) {
                  return weights;
              }
          }
          for (const key of keys) {
              const weights = fromList(obj, key);
              if (weights) {
                  return weights;
              }
          }
          return null;
      }
  };
  /**
   * Error type thrown by the NumPy loader.
   */
  numpy.Error = class extends Error {

      /**
       * @param {string} message - Human-readable description of the failure.
       */
      constructor(message) {
          super(message);
          // This is the NumPy loader; the name previously said 'Chainer'.
          this.name = 'Error loading NumPy model.';
      }
  };
  // Expose the factory class as a named export of this module.
  235. export const ModelFactory = numpy.ModelFactory;