// nnabla.js (12 KB)
  // Namespace collecting the NNabla (Neural Network Libraries) model reader classes.
  const nnabla = {};

  /**
   * Recognises NNabla text-format network files (`.nntxt`) and builds a model from
   * them plus the optional sibling files `nnp_version.txt`, `parameter.protobuf`
   * and `parameter.h5`.
   */
  nnabla.ModelFactory = class {

    /**
     * Claims a file whose name ends in `.nntxt` and whose protobuf-text tags contain `network`.
     * @param {object} context - host file context (`identifier`, `tags()`, `set()`).
     * @returns {Promise<*>} the value of `context.set('nnabla.pbtxt')`, or null when not matched.
     */
    async match(context) {
      const identifier = context.identifier;
      if (identifier.endsWith('.nntxt')) {
        const tags = await context.tags('pbtxt');
        if (tags.has('network')) {
          return context.set('nnabla.pbtxt');
        }
      }
      return null;
    }

    /**
     * Decodes the network definition and attaches the version and parameter data
     * found in sibling files. Loads `./nnabla-proto` and stores its `nnabla`
     * export on `nnabla.proto` on every call.
     * @param {object} context - host file context previously claimed by `match`.
     * @returns {Promise<nnabla.Model>} the assembled model.
     * @throws {nnabla.Error} when `context.type` is not 'nnabla.pbtxt'.
     */
    async open(context) {
      nnabla.proto = await context.require('./nnabla-proto');
      nnabla.proto = nnabla.proto.nnabla;
      switch (context.type) {
        case 'nnabla.pbtxt': {
          const reader = await context.read('protobuf.text');
          const model = nnabla.proto.NNablaProtoBuf.decodeText(reader);
          // Sibling files are optional: a failed fetch is treated as "file not present".
          const files = ['nnp_version.txt', 'parameter.protobuf', 'parameter.h5'];
          const fetched = await Promise.all(files.map((file) => context.fetch(file).catch(() => null)));
          const siblings = new Map(fetched
            .filter((entry) => entry !== null)
            .map((entry) => [entry.identifier, entry]));
          let version = '';
          if (siblings.has('nnp_version.txt')) {
            const text = await siblings.get('nnp_version.txt').read('text');
            const line = text.read('\n');
            // NOTE(review): an empty version file presumably yields no line; keep the
            // empty version rather than aborting the whole load on `undefined.split`.
            if (typeof line === 'string') {
              // Keep only the text before any carriage return (CRLF line endings).
              version = line.split('\r').shift();
            }
          }
          if (siblings.has('parameter.protobuf')) {
            // Binary parameters replace whatever parameters the text definition carried.
            const binary = await siblings.get('parameter.protobuf').read('protobuf.binary');
            model.parameter = nnabla.proto.NNablaProtoBuf.decode(binary).parameter;
          } else if (siblings.has('parameter.h5')) {
            const file = await siblings.get('parameter.h5').read('hdf5');
            // Breadth-first walk of the HDF5 group tree; every dataset becomes a
            // Parameter named by its '/'-joined group path.
            const queue = [['', file]];
            while (queue.length > 0) {
              const [name, group] = queue.shift();
              if (group.value) {
                const variable = group.value;
                // Copy into a fresh buffer so the Float32Array view below starts at
                // byte offset 0 (typed-array views must be element-aligned).
                const data = variable.data.peek();
                const buffer = new Uint8Array(data.length);
                buffer.set(data, 0);
                const parameter = new nnabla.proto.Parameter();
                parameter.variable_name = name;
                parameter.shape = new nnabla.proto.Shape();
                parameter.shape.dim = variable.shape.map((dim) => BigInt(dim));
                parameter.data = new Float32Array(buffer.buffer, buffer.byteOffset, buffer.byteLength >> 2);
                model.parameter.push(parameter);
              } else {
                for (const [key, value] of group.groups) {
                  queue.push([name ? `${name}/${key}` : key, value]);
                }
              }
            }
          }
          const metadata = await context.metadata('nnabla-metadata.json');
          return new nnabla.Model(metadata, model, version);
        }
        default: {
          throw new nnabla.Error(`Unsupported nnabla format '${context.type}'.`);
        }
      }
    }

    /**
     * While an NNabla text model is open, rejects competing 'hdf5.parameter.h5' and
     * 'keras.h5' matches (presumably because a sibling `parameter.h5` is consumed
     * by `open` as this model's weights).
     * @returns {boolean} false to suppress `match`, true otherwise.
     */
    filter(context, match) {
      return context.type !== 'nnabla.pbtxt' || (match.type !== 'hdf5.parameter.h5' && match.type !== 'keras.h5');
    }
  };
  /**
   * A loaded NNabla model holding one graph per executor, optimizer and monitor entry,
   * in that order.
   */
  nnabla.Model = class {

    /**
     * @param {object} metadata - operator metadata, passed through to every graph.
     * @param {object} model - decoded NNablaProtoBuf message.
     * @param {string} version - version text; omitted from `format` when empty.
     */
    constructor(metadata, model, version) {
      this.format = version ? `NNabla v${version}` : 'NNabla';
      this.modules = [];
      // Parameter tensors keyed by variable name; graphs use them as initializers.
      const tensors = new Map();
      for (const parameter of model.parameter) {
        const name = parameter.variable_name;
        const type = new nnabla.TensorType(new nnabla.TensorShape(parameter.shape.dim));
        tensors.set(name, new nnabla.Tensor(name, type, parameter.data));
      }
      const networks = new Map();
      for (const network of model.network) {
        networks.set(network.name, network);
      }
      // Each entry names a network and its data (input) variables; the three kinds
      // differ only in which field lists the graph outputs.
      const sources = [
        [model.executor, 'output_variable'],
        [model.optimizer, 'loss_variable'],
        [model.monitor, 'monitor_variable'],
      ];
      for (const [entries, outputField] of sources) {
        for (const entry of entries) {
          const network = networks.get(entry.network_name);
          this.modules.push(new nnabla.Graph(metadata, network, entry.data_variable, entry[outputField], tensors));
        }
      }
    }
  };
  /**
   * One network as seen from an executor, optimizer or monitor: graph-level input and
   * output arguments plus one node per network function.
   */
  nnabla.Graph = class {

    /**
     * @param {object} metadata - operator metadata. Results of `type()` and `attribute()`
     *   are treated as optional (unknown operators/attributes get fallbacks).
     * @param {object} network - decoded network message (`name`, `variable`, `function`).
     * @param {Array<{variable_name: string}>} inputs - graph input variables.
     * @param {Array<{variable_name: string}>} outputs - graph output variables.
     * @param {Map<string, nnabla.Tensor>} tensors - parameter tensors keyed by variable name.
     */
    constructor(metadata, network, inputs, outputs, tensors) {
      this.name = network.name;
      const values = new Map(network.variable.map((variable) => {
        const name = variable.name;
        const shape = new nnabla.TensorShape(variable.shape.dim);
        const type = new nnabla.TensorType(shape);
        return [name, new nnabla.Value(name, type, tensors.get(name))];
      }));
      // Returns the Value for `name`, creating an untyped one on first use for
      // variables the network does not declare.
      const resolve = (name) => {
        if (!values.has(name)) {
          values.set(name, new nnabla.Value(name, null, tensors.get(name)));
        }
        return values.get(name);
      };
      // Splits a flat list of variable names into Arguments following the metadata
      // schema: a `list` entry absorbs all remaining names, and names beyond the
      // schema (or with no schema at all) get positional names '0', '1', ...
      const group = (names, schema) => {
        const args = [];
        for (let index = 0; index < names.length;) {
          const spec = schema && index < schema.length ? schema[index] : { name: index.toString() };
          const count = spec.list ? names.length - index : 1;
          args.push(new nnabla.Argument(spec.name, names.slice(index, index + count).map((name) => resolve(name))));
          index += count;
        }
        return args;
      };
      // Returns the first field whose key ends in '_param' (the function's settings).
      const get_parameters = (func) => {
        for (const [key, value] of Object.entries(func)) {
          if (key.endsWith('_param')) {
            return value;
          }
        }
        return undefined;
      };
      this.inputs = inputs.map((item) => new nnabla.Argument(item.variable_name, [resolve(item.variable_name)]));
      this.outputs = outputs.map((item) => new nnabla.Argument(item.variable_name, [resolve(item.variable_name)]));
      this.nodes = network.function.map((func) => {
        const parameters = get_parameters(func) || [];
        const attributes = Object.entries(parameters).map(([name, raw]) => {
          // Attributes without metadata are kept untyped and visible instead of crashing.
          const attribute = metadata.attribute(func.type, name) || {};
          let type = attribute.type;
          let data = raw;
          if (type === 'shape') {
            // Shape messages are presented as their dimension list.
            type = 'int64[]';
            data = raw.dim;
          }
          // Hide attributes that still hold their documented default value.
          const visible = !(attribute.default !== undefined && data === attribute.default);
          return new nnabla.Argument(name, data, type, visible);
        });
        // Operators missing from the metadata fall back to positional argument names,
        // matching the fallback nnabla.Node applies to its own type lookup.
        const func_type = metadata.type(func.type) || {};
        const node_inputs = group(func.input, func_type.inputs);
        const node_outputs = group(func.output, func_type.outputs);
        return new nnabla.Node(metadata, func, attributes, node_inputs, node_outputs);
      });
    }
  };
  /** A named slot on a graph or node: a list of Values or a single attribute value. */
  nnabla.Argument = class {

    /**
     * @param {string} name - argument name.
     * @param {*} value - Value list for inputs/outputs, raw data for attributes.
     * @param {string} [type] - attribute type name; stored as null when missing or empty.
     * @param {boolean} [visible] - only an explicit `false` hides the argument.
     */
    constructor(name, value, type, visible) {
      this.name = name;
      this.value = value;
      this.type = type ? type : null;
      this.visible = visible === false ? false : true;
    }
  };
  /** A named variable flowing between nodes, optionally backed by a parameter tensor. */
  nnabla.Value = class {

    /**
     * @param {string} name - variable name.
     * @param {nnabla.TensorType|null} type - declared type; when falsy, the
     *   initializer's type (if it has one) is used instead.
     * @param {nnabla.Tensor} [initializer] - parameter data; stored as null when absent.
     */
    constructor(name, type, initializer) {
      this.name = name;
      if (type) {
        this.type = type;
      } else if (initializer && initializer.type) {
        this.type = initializer.type;
      } else {
        this.type = type;
      }
      this.initializer = initializer || null;
    }
  };
  /**
   * A graph node for one NNabla function. Fused operators are shown as their leading
   * operation, with the remaining fused steps listed as nodes in `chain`.
   */
  nnabla.Node = class {

    /**
     * @param {object} metadata - operator metadata; `type()` may return nothing.
     * @param {{name: string, type: string}} func - function message or synthetic stand-in.
     * @param {nnabla.Argument[]} [attributes] - defaults to [].
     * @param {nnabla.Argument[]} [inputs] - defaults to [] for non-fused types.
     * @param {nnabla.Argument[]} [outputs] - defaults to [].
     */
    constructor(metadata, func, attributes, inputs, outputs) {
      this.name = func.name;
      this.type = metadata.type(func.type) || { name: func.type, type: func.type };
      this.attributes = attributes || [];
      this.outputs = outputs || [];
      this.chain = [];
      // "nonlinearity" values are snake_case and do not match the metadata type names.
      const get_nonlinearity = (name) => {
        switch (name) {
          case 'identity': return 'Identity';
          case 'relu': return 'ReLU';
          case 'sigmoid': return 'Sigmoid';
          case 'tanh': return 'Tanh';
          case 'leaky_relu': return 'LeakyReLU';
          case 'elu': return 'ELU';
          case 'relu6': return 'ReLU6';
          default: return name;
        }
      };
      // Appends the fused activation as a chain node; a function without a
      // `nonlinearity` attribute gets none instead of throwing.
      const push_activation = () => {
        const nonlinearity = this.attributes.find((item) => item.name === 'nonlinearity');
        if (nonlinearity) {
          this.chain.push(new nnabla.Node(metadata, { name: `${func.name}/act`, type: get_nonlinearity(nonlinearity.value) }));
        }
      };
      switch (func.type) {
        case 'FusedConvolution': {
          // Inputs 0-2 stay on this node, 3-6 go to the `/bn` step, 7+ to the `/add` step.
          this.inputs = inputs.slice(0, 3);
          if (inputs.length > 3) {
            this.chain.push(new nnabla.Node(metadata, { name: `${func.name}/bn`, type: 'BatchNormalization' }, [], inputs.slice(3, 7)));
          }
          if (inputs.length > 7) {
            this.chain.push(new nnabla.Node(metadata, { name: `${func.name}/add`, type: 'Add2' }, [], inputs.slice(7)));
          }
          push_activation();
          break;
        }
        case 'FusedBatchNormalization': {
          // Inputs 0-4 stay on this node; anything from index 5 on feeds the `/add` step.
          this.inputs = inputs.slice(0, 5);
          // Only chain an Add2 when an addend actually exists (previously `> 4`,
          // which added an empty Add2 for exactly five inputs).
          if (inputs.length > 5) {
            this.chain.push(new nnabla.Node(metadata, { name: `${func.name}/add`, type: 'Add2' }, [], inputs.slice(5)));
          }
          push_activation();
          break;
        }
        default: {
          this.inputs = inputs || [];
          break;
        }
      }
    }
  };
  /** A parameter tensor whose data is materialized as a typed array. */
  nnabla.Tensor = class {

    /**
     * @param {string} name - variable name the tensor initializes.
     * @param {nnabla.TensorType} type - only a `dataType` of 'float32' is accepted.
     * @param {*} values - raw parameter data, copied into a new Float32Array.
     * @throws {nnabla.Error} for any other data type.
     */
    constructor(name, type, values) {
      this.name = name;
      this.type = type;
      this.encoding = '|';
      const dataType = type.dataType;
      if (dataType !== 'float32') {
        throw new nnabla.Error(`Unsupported data type '${dataType}'.`);
      }
      this.values = new Float32Array(values);
    }
  };
  /** Tensor type descriptor; the data type is always 'float32' here. */
  nnabla.TensorType = class {

    /** @param {nnabla.TensorShape} shape - dimensions of the tensor. */
    constructor(shape) {
      this.dataType = 'float32';
      this.shape = shape;
    }

    /** @returns {string} the data type followed by the shape's string form. */
    toString() {
      const dims = this.shape.toString();
      return `${this.dataType}${dims}`;
    }
  };
  /** Tensor dimensions, normalised to plain numbers. */
  nnabla.TensorShape = class {

    /**
     * @param {Array<number|bigint>} dimensions - bigint entries (such as the
     *   `BigInt(dim)` values built by the HDF5 loader) are converted to numbers.
     */
    constructor(dimensions) {
      // Native BigInt has no toNumber() method (the old call threw a TypeError);
      // Number() is the standard conversion. Values above 2^53 lose precision.
      this.dimensions = dimensions.map((dim) => (typeof dim === 'bigint' ? Number(dim) : dim));
    }

    /** @returns {string} '[d0,d1,...]', or '' when there are no dimensions. */
    toString() {
      if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
        return `[${this.dimensions.join(',')}]`;
      }
      return '';
    }
  };
  /** Error type thrown by this reader when an NNabla model cannot be loaded. */
  nnabla.Error = class extends Error {

    /** @param {string} message - description of what failed. */
    constructor(message) {
      super(message);
      // Replace the inherited 'Error' name with a reader-specific label.
      this.name = 'Error loading Neural Network Library model.';
    }
  };
  // Public entry point: the factory the host uses to match and open NNabla models.
  const ModelFactory = nnabla.ModelFactory;
  export { ModelFactory };