// nnabla.js (14 KB)
  // Namespace and dependency bindings. Each one reuses a global of the same name
  // when it is already defined (presumably by a previously loaded script — TODO
  // confirm), otherwise it creates the namespace or loads the CommonJS module.
  var nnabla = nnabla ? nnabla : {};
  var protobuf = protobuf ? protobuf : require('./protobuf');
  var text = text ? text : require('./text');
  nnabla.ModelFactory = class {

      /**
       * Recognizes NNabla text-format networks: `.nntxt` files whose pbtxt tags include `network`.
       * @param {object} context - host open context (identifier, tags()).
       * @returns {string|undefined} 'nnabla.pbtxt' on a match, otherwise undefined.
       */
      match(context) {
          const identifier = context.identifier;
          if (identifier.endsWith('.nntxt')) {
              const tags = context.tags('pbtxt');
              if (tags.has('network')) {
                  return 'nnabla.pbtxt';
              }
          }
          return undefined;
      }

      /**
       * Decodes the matched model and resolves to an nnabla.Model.
       *
       * The sibling files `nnp_version.txt` and `parameter.protobuf` are optional. Each is loaded
       * independently, so one being absent or unreadable does not discard the other. Errors raised
       * while loading metadata or building the model are propagated, not retried.
       * @param {object} context - host open context (stream, request(), metadata(), require()).
       * @param {string} match - the value returned by match().
       * @returns {Promise<nnabla.Model>}
       * @throws {nnabla.Error} for an unsupported match value (as a rejected promise).
       */
      open(context, match) {
          return context.require('./nnabla-proto').then(() => {
              nnabla.proto = protobuf.get('nnabla').nnabla;
              switch (match) {
                  case 'nnabla.pbtxt': {
                      const stream = context.stream;
                      const reader = protobuf.TextReader.open(stream);
                      const model = nnabla.proto.NNablaProtoBuf.decodeText(reader);
                      // Best-effort optional siblings: each resolves to null when missing or unreadable.
                      const versionPromise = context.request('nnp_version.txt', null).then((versionStream) => {
                          return text.Reader.open(versionStream).read();
                      }).catch(() => null);
                      const parameterPromise = context.request('parameter.protobuf', null).then((parameterStream) => {
                          const binaryReader = protobuf.BinaryReader.open(parameterStream);
                          return nnabla.proto.NNablaProtoBuf.decode(binaryReader).parameter;
                      }).catch(() => null);
                      return Promise.all([ versionPromise, parameterPromise ]).then(([ version, parameter ]) => {
                          if (parameter) {
                              model.parameter = parameter;
                          }
                          return context.metadata('nnabla-metadata.json').then((metadata) => {
                              return new nnabla.Model(metadata, model, 'NNabla' + (version ? ' v' + version : ''));
                          });
                      });
                  }
                  default: {
                      throw new nnabla.Error("Unsupported nnabla format '" + match + "'.");
                  }
              }
          });
      }
  };
  nnabla.Model = class {

      /**
       * Top-level model: a format label plus exactly one graph.
       * @param {object} metadata - passed through to nnabla.Graph.
       * @param {object} model - decoded NNablaProtoBuf message, passed through to nnabla.Graph.
       * @param {string} format - display label such as 'NNabla' or 'NNabla v<version>'.
       */
      constructor(metadata, model, format) {
          const graph = new nnabla.Graph(metadata, model);
          this._format = format;
          this._graphs = [ graph ];
      }

      /** @returns {string} the format label given at construction. */
      get format() {
          return this._format;
      }

      /** @returns {nnabla.Graph[]} a single-element list. */
      get graphs() {
          return this._graphs;
      }
  };
  nnabla.Graph = class {

      /**
       * Builds nodes, inputs and outputs for the network referenced by the model's first executor.
       * @param {object} metadata - schema provider; `type(name)` may return undefined for unknown functions.
       * @param {object} model - decoded NNablaProtoBuf message (`executor`, `network`, optional `parameter`).
       * @throws {nnabla.Error} if the executor names a network that is not present in the model.
       */
      constructor(metadata, model) {
          const executor = model.executor[0]; // TODO: Multiple executors?
          const network_name = executor.network_name;
          const network = model.network.find((item) => item.name === network_name);
          if (!network) {
              throw new nnabla.Error("Network '" + network_name + "' not found.");
          }
          this._dataTypes = new Map(network.variable.map((item) => {
              const shape = new nnabla.TensorShape(item.shape ? item.shape.dim : []);
              const type = new nnabla.TensorType(item.type, shape);
              return [ item.name, type ];
          }));
          // `parameter` may be absent when no weights were loaded alongside the network.
          const weights = model.parameter || [];
          this._tensors = new Map(weights.map((item) => {
              const name = item.variable_name;
              return [ name, new nnabla.Tensor(name, this.dataType(name), item.data) ];
          }));
          this._arguments = new Map();
          this._inputs = executor.data_variable.map((item) => {
              const name = item.variable_name;
              return new nnabla.Parameter(name, [ this.argument(name) ]);
          });
          this._outputs = executor.output_variable.map((item) => {
              const name = item.variable_name;
              return new nnabla.Parameter(name, [ this.argument(name) ]);
          });
          // Returns the value of the first `*_param` field of a function message, if any.
          const get_parameters = (func) => {
              for (const [key, value] of Object.entries(func)) {
                  if (key.endsWith("_param")) {
                      return value;
                  }
              }
              return undefined;
          };
          // Groups argument names into Parameters following the schema list. A schema marked
          // `list` absorbs all remaining names; names beyond the schema are labelled by index.
          const group = (names, schemas) => {
              const result = [];
              for (let index = 0; index < names.length; ) {
                  const schema = schemas && index < schemas.length ? schemas[index] : { name: index.toString() };
                  const count = schema.list ? names.length - index : 1;
                  const args = names.slice(index, index + count).map((name) => this.argument(name));
                  result.push(new nnabla.Parameter(schema.name, args));
                  index += count;
              }
              return result;
          };
          this._nodes = network.function.map((func) => {
              const parameters = get_parameters(func) || [];
              const attributes = Object.entries(parameters).map(([name, value]) => {
                  return new nnabla.Attribute(metadata, func.type, name, value);
              });
              // Function types missing from the metadata get an empty schema instead of throwing.
              const func_type = metadata.type(func.type) || {};
              const inputs = group(func.input, func_type.inputs);
              const outputs = group(func.output, func_type.outputs);
              return new nnabla.Node(metadata, func, attributes, inputs, outputs);
          });
      }

      get nodes() {
          return this._nodes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      /** @returns {nnabla.TensorType|undefined} the declared type of a network variable. */
      dataType(name) {
          return this._dataTypes.get(name);
      }

      /** @returns {nnabla.Tensor|undefined} the loaded parameter tensor for a variable, if any. */
      tensor(name) {
          return this._tensors.get(name);
      }

      /** Returns the shared Argument for `name`, creating and caching it on first use. */
      argument(name) {
          if (!this._arguments.has(name)) {
              this._arguments.set(name, new nnabla.Argument(name, this.dataType(name), this.tensor(name)));
          }
          return this._arguments.get(name);
      }
  };
  nnabla.Parameter = class {

      /**
       * A named input/output slot holding a list of arguments.
       * @param {string} name - slot name.
       * @param {nnabla.Argument[]} args - arguments bound to the slot.
       */
      constructor(name, args) {
          Object.assign(this, { _name: name, _arguments: args });
      }

      get name() {
          return this._name;
      }

      /** Parameters are always shown. */
      get visible() {
          return true;
      }

      get arguments() {
          return this._arguments;
      }
  };
  nnabla.Argument = class {

      /**
       * A named value flowing between nodes, optionally backed by an initializer tensor.
       * @param {string} name
       * @param {nnabla.TensorType} [type] - declared type; falsy values are stored as null.
       * @param {nnabla.Tensor} [initializer] - constant data; falsy values are stored as null.
       */
      constructor(name, type, initializer) {
          this._name = name;
          this._type = type ? type : null;
          this._initializer = initializer ? initializer : null;
      }

      get name() {
          return this._name;
      }

      /** The declared type, else the initializer's type, else null. */
      get type() {
          if (this._type) {
              return this._type;
          }
          return this._initializer ? this._initializer.type : null;
      }

      get initializer() {
          return this._initializer;
      }
  };
  nnabla.Node = class {

      /**
       * A function node. FusedConvolution and FusedBatchNormalization keep their leading inputs
       * and move the trailing ones into a `chain` of synthetic nodes (BatchNormalization, Add2,
       * and an activation named by the `nonlinearity` attribute).
       * @param {object} metadata - provides type(name); unknown types fall back to { name, type }.
       * @param {object} func - function message with at least `name` and `type`.
       * @param {nnabla.Attribute[]} [attributes]
       * @param {nnabla.Parameter[]} [inputs] - required for the fused types.
       * @param {nnabla.Parameter[]} [outputs]
       */
      constructor(metadata, func, attributes, inputs, outputs) {
          this._name = func.name;
          this._type = metadata.type(func.type) || { name: func.type, type: func.type };
          this._attributes = attributes || [];
          this._outputs = outputs || [];
          this._chain = [];
          // TODO: "nonlinearity" does not match metadata type
          const get_nonlinearity = (name) => {
              switch (name) {
                  case "identity": return "Identity";
                  case "relu": return "ReLU";
                  case "sigmoid": return "Sigmoid";
                  case "tanh": return "Tanh";
                  case "leaky_relu": return "LeakyReLU";
                  case "elu": return "ELU";
                  case "relu6": return "ReLU6";
                  default: return name;
              }
          };
          // Appends the activation chain node when a non-empty `nonlinearity` attribute exists.
          // The attribute may be absent (e.g. the field was not set), which is not an error.
          const push_activation = () => {
              const nonlinearity = this._attributes.find((item) => item.name === "nonlinearity");
              if (nonlinearity && nonlinearity.value) {
                  this._chain.push(new nnabla.Node(metadata, { name: func.name + "/act", type: get_nonlinearity(nonlinearity.value) }));
              }
          };
          switch (func.type) {
              case "FusedConvolution": {
                  // Inputs 0-2 stay here, 3-6 feed a BatchNormalization chain node, 7+ feed an Add2 chain node.
                  this._inputs = inputs.slice(0, 3);
                  if (inputs.length > 3) {
                      this._chain.push(new nnabla.Node(metadata, { name: func.name + "/bn", type: "BatchNormalization" }, [], inputs.slice(3, 7)));
                  }
                  if (inputs.length > 7) {
                      this._chain.push(new nnabla.Node(metadata, { name: func.name + "/add", type: "Add2" }, [], inputs.slice(7)));
                  }
                  push_activation();
                  break;
              }
              case "FusedBatchNormalization": {
                  // Inputs 0-4 stay here; only a sixth or later input feeds an Add2 chain node.
                  this._inputs = inputs.slice(0, 5);
                  if (inputs.length > 5) {
                      this._chain.push(new nnabla.Node(metadata, { name: func.name + "/add", type: "Add2" }, [], inputs.slice(5)));
                  }
                  push_activation();
                  break;
              }
              default: {
                  this._inputs = inputs || [];
                  break;
              }
          }
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      get attributes() {
          return this._attributes;
      }

      get inputs() {
          return this._inputs;
      }

      get outputs() {
          return this._outputs;
      }

      get chain() {
          return this._chain;
      }
  };
  nnabla.Attribute = class {

      /**
       * A function attribute, typed and described from metadata when available.
       * @param {object} metadata - provides attribute(type, name); may return undefined for unknown entries.
       * @param {string} type - function type name.
       * @param {string} name - attribute name.
       * @param {*} value - raw decoded value; for 'Shape' attributes its `dim` list is used.
       */
      constructor(metadata, type, name, value) {
          this._name = name;
          // Unknown functions/attributes get an empty schema instead of crashing on `.description`.
          const attribute = metadata.attribute(type, name) || {};
          this._description = attribute.description;
          switch (attribute.type) {
              case "Shape":
                  this._type = "int64[]";
                  this._value = value.dim;
                  break;
              default:
                  this._type = attribute.type;
                  this._value = value;
                  break;
          }
          // Loose equality is kept from the original comparison against the metadata default.
          if (Object.prototype.hasOwnProperty.call(attribute, 'default') && this._value == attribute.default) {
              this._visible = false;
          }
      }

      get name() {
          return this._name;
      }

      get description() {
          return this._description;
      }

      get type() {
          return this._type;
      }

      get value() {
          return this._value;
      }

      /** False only when the value equals the metadata default. */
      get visible() {
          return this._visible == false ? false : true;
      }
  };
  nnabla.Tensor = class {

      /**
       * A parameter tensor whose raw values are decoded lazily for display.
       * @param {string} name
       * @param {nnabla.TensorType|undefined} type - may be undefined when no network variable matches.
       * @param {number[]} values - raw element values, decoded as float32.
       */
      constructor(name, type, values) {
          this._name = name;
          this._type = type;
          this._values = values;
      }

      /** Never assigned in this class, so currently always undefined. */
      get kind() {
          return this._kind;
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** @returns {string|null} a reason the data cannot be decoded, or null. */
      get state() {
          return this._context().state;
      }

      /** @returns {*} nested arrays of values (a bare number for scalars), or null when undecodable. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} pretty-printed JSON of at most ~10000 values, or '' when undecodable. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return JSON.stringify(value, null, 4);
      }

      // Builds the decode cursor, or sets `state` explaining why decoding is impossible.
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          if (!this._values || this._values.length === 0) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          if (!this._type) {
              context.state = 'Tensor has no type.';
              return context;
          }
          switch (this._type.dataType) {
              case 'float32':
                  context.data = new Float32Array(this._values);
                  break;
              default:
                  context.state = 'Unknown data type.';
                  return context;
          }
          // A missing dimension list is treated as rank 0 (scalar).
          context.shape = (this._type.shape && this._type.shape.dimensions) || [];
          context.dataType = this._type.dataType;
          return context;
      }

      // Recursively converts the flat data into nested arrays, one level per dimension.
      _decode(context, dimension) {
          if (context.shape.length === 0) {
              // Rank-0 tensor: there is no dimension to iterate, so yield the single element.
              const scalar = context.data[context.index];
              context.index++;
              context.count++;
              return scalar;
          }
          const results = [];
          const size = context.shape[dimension];
          if (dimension === context.shape.length - 1) {
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  switch (context.dataType) {
                      case 'float32':
                          results.push(context.data[context.index]);
                          break;
                      default:
                          context.state = 'Unknown data type.';
                          break;
                  }
                  context.index++;
                  context.count++;
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          return results;
      }
  };
  nnabla.TensorType = class {

      /**
       * Element type plus shape of a network variable.
       * @param {*} dataType - ignored; every instance reports 'float32'. (The value passed in is a
       *     variable's `type` field, which presumably is not an element type — TODO confirm.)
       * @param {nnabla.TensorShape} shape
       */
      constructor(dataType, shape) {
          this._dataType = "float32";
          this._shape = shape;
          this._denotation = null; // TODO
      }

      get dataType() {
          return this._dataType;
      }

      get shape() {
          return this._shape;
      }

      get denotation() {
          return this._denotation;
      }

      /** e.g. 'float32[1,3]'. */
      toString() {
          const shapeText = this._shape.toString();
          return `${this._dataType}${shapeText}`;
      }
  };
  nnabla.TensorShape = class {

      /** @param {number[]} dimensions - dimension sizes. */
      constructor(dimensions) {
          this._dimensions = dimensions;
      }

      get dimensions() {
          return this._dimensions;
      }

      /** '[d0,d1,...]', or '' when there are no dimensions. */
      toString() {
          const dims = this._dimensions;
          if (!dims || !dims.length) {
              return '';
          }
          return `[${dims.join(',')}]`;
      }
  };
  /** Error raised for models this loader cannot handle. */
  nnabla.Error = class extends Error {

      /** @param {string} text - human-readable message. */
      constructor(text) {
          super(text);
          this.name = 'Error loading Neural Network Library model.';
      }
  };
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      // Expose the factory to CommonJS loaders.
      const target = module.exports;
      target.ModelFactory = nnabla.ModelFactory;
  }