nnabla.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. var nnabla = {};
  2. var protobuf = require('./protobuf');
  3. var text = require('./text');
  4. nnabla.ModelFactory = class {
  5. match(context) {
  6. const identifier = context.identifier;
  7. if (identifier.endsWith('.nntxt')) {
  8. const tags = context.tags('pbtxt');
  9. if (tags.has('network')) {
  10. return 'nnabla.pbtxt';
  11. }
  12. }
  13. return undefined;
  14. }
  15. open(context, match) {
  16. return context.require('./nnabla-proto').then(() => {
  17. nnabla.proto = protobuf.get('nnabla').nnabla;
  18. switch (match) {
  19. case 'nnabla.pbtxt': {
  20. const stream = context.stream;
  21. const reader = protobuf.TextReader.open(stream);
  22. const model = nnabla.proto.NNablaProtoBuf.decodeText(reader);
  23. const promises = [
  24. context.request('nnp_version.txt', null),
  25. context.request('parameter.protobuf', null)
  26. ];
  27. const open = (model, version) => {
  28. return context.metadata('nnabla-metadata.json').then((metadata) => {
  29. return new nnabla.Model(metadata, model, 'NNabla' + (version ? ' v' + version : ''));
  30. });
  31. };
  32. return Promise.all(promises).then((streams) => {
  33. const version = text.Reader.open(streams[0]).read();
  34. const reader = protobuf.BinaryReader.open(streams[1]);
  35. const params = nnabla.proto.NNablaProtoBuf.decode(reader);
  36. model.parameter = params.parameter;
  37. return open(model, version);
  38. }).catch(() => {
  39. return open(model);
  40. });
  41. }
  42. default: {
  43. throw new nnabla.Error("Unsupported nnabla format '" + match + "'.");
  44. }
  45. }
  46. });
  47. }
  48. };
  49. nnabla.Model = class {
  50. constructor(metadata, model, format) {
  51. this._format = format;
  52. this._graphs = [ new nnabla.Graph(metadata, model) ];
  53. }
  54. get format() {
  55. return this._format;
  56. }
  57. get graphs() {
  58. return this._graphs;
  59. }
  60. };
  61. nnabla.Graph = class {
  62. constructor (metadata, model) {
  63. const executor = model.executor[0]; // TODO: Multiple executors?
  64. const network_name = executor.network_name;
  65. const network = model.network.find((item) => item.name === network_name);
  66. const dataTypes = new Map(network.variable.map((item) => {
  67. const shape = new nnabla.TensorShape(item.shape.dim);
  68. const type = new nnabla.TensorType(item.type, shape);
  69. return [ item.name, type ];
  70. }));
  71. const tensors = new Map(model.parameter.map((item) => {
  72. const name = item.variable_name;
  73. return [ name, new nnabla.Tensor(name, dataTypes.get(name), item.data) ];
  74. }));
  75. const args = new Map();
  76. const arg = (name) => {
  77. if (!args.has(name)) {
  78. args.set(name, new nnabla.Argument(name, dataTypes.get(name), tensors.get(name)));
  79. }
  80. return args.get(name);
  81. };
  82. this._inputs = executor.data_variable.map((item) => {
  83. const name = item.variable_name;
  84. return new nnabla.Parameter(name, [ arg(name) ]);
  85. });
  86. this._outputs = executor.output_variable.map((item) => {
  87. const name = item.variable_name;
  88. return new nnabla.Parameter(name, [ arg(name) ]);
  89. });
  90. const get_parameters = (func) => {
  91. for (const [key, value] of Object.entries(func)) {
  92. if (key.endsWith("_param")) {
  93. return value;
  94. }
  95. }
  96. return undefined;
  97. };
  98. this._nodes = network.function.map((func) => {
  99. const parameters = get_parameters(func) || [];
  100. const attributes = Object.entries(parameters).map(([name, value]) => {
  101. return new nnabla.Attribute(metadata, func.type, name, value);
  102. });
  103. const func_type = metadata.type(func.type);
  104. const inputs = [];
  105. for (let index = 0; index < func.input.length;) {
  106. const input = func_type.inputs && index < func_type.inputs.length ? func_type.inputs[index] : { name: index.toString() };
  107. const count = input.list ? func.input.length - index : 1;
  108. const args = func.input.slice(index, index + count).map((input) => arg(input));
  109. inputs.push(new nnabla.Parameter(input.name, args));
  110. index += count;
  111. }
  112. const outputs = [];
  113. for (let index = 0; index < func.output.length;) {
  114. const output = func_type.outputs && index < func_type.outputs.length ? func_type.outputs[index] : { name: index.toString() };
  115. const count = output.list ? func.output.length - index : 1;
  116. const args = func.output.slice(index, index + count).map((output) => arg(output));
  117. outputs.push(new nnabla.Parameter(output.name, args));
  118. index += count;
  119. }
  120. return new nnabla.Node(metadata, func, attributes, inputs, outputs);
  121. });
  122. }
  123. get nodes() {
  124. return this._nodes;
  125. }
  126. get inputs() {
  127. return this._inputs;
  128. }
  129. get outputs() {
  130. return this._outputs;
  131. }
  132. };
  133. nnabla.Parameter = class {
  134. constructor(name, args) {
  135. this._name = name;
  136. this._arguments = args;
  137. }
  138. get name() {
  139. return this._name;
  140. }
  141. get visible() {
  142. return true;
  143. }
  144. get arguments() {
  145. return this._arguments;
  146. }
  147. };
  148. nnabla.Argument = class {
  149. constructor(name, type, initializer) {
  150. this._name = name;
  151. this._type = type || null;
  152. this._initializer = initializer || null;
  153. }
  154. get name() {
  155. return this._name;
  156. }
  157. get type() {
  158. if (this._type) {
  159. return this._type;
  160. }
  161. if (this._initializer) {
  162. return this._initializer.type;
  163. }
  164. return null;
  165. }
  166. get initializer() {
  167. return this._initializer;
  168. }
  169. };
  170. nnabla.Node = class {
  171. constructor(metadata, func, attributes, inputs, outputs) {
  172. this._name = func.name;
  173. this._type = metadata.type(func.type) || { name: func.type, type: func.type };
  174. this._attributes = attributes || [];
  175. this._outputs = outputs || [];
  176. this._chain = [];
  177. // TODO: "nonlinearity" does not match metadata type
  178. const get_nonlinearity = (name) => {
  179. switch (name) {
  180. case "identity": return "Identity";
  181. case "relu": return "ReLU";
  182. case "sigmoid": return "Sigmoid";
  183. case "tanh": return "Tanh";
  184. case "leaky_relu": return "LeakyReLU";
  185. case "elu": return "ELU";
  186. case "relu6": return "ReLU6";
  187. default: return name;
  188. }
  189. };
  190. switch (func.type) {
  191. case "FusedConvolution": {
  192. this._inputs = inputs.slice(0, 3) || [];
  193. if (inputs.length > 3) {
  194. this._chain.push(new nnabla.Node(metadata, { name: func.name + "/bn", type: "BatchNormalization" }, [], inputs.slice(3, 7)));
  195. }
  196. if (inputs.length > 7) {
  197. this._chain.push(new nnabla.Node(metadata, { name: func.name + "/add", type: "Add2" }, [], inputs.slice(7)));
  198. }
  199. const type_a = attributes.find((item) => item.name === "nonlinearity").value;
  200. this._chain.push(new nnabla.Node(metadata, { name: func.name + "/act", type: get_nonlinearity(type_a) }));
  201. break;
  202. }
  203. case "FusedBatchNormalization": {
  204. this._inputs = inputs.slice(0, 5) || [];
  205. if (inputs.length > 4) {
  206. this._chain.push(new nnabla.Node(metadata, { name: func.name + "/add", type: "Add2" }, [], inputs.slice(5)));
  207. }
  208. const type_b = attributes.find((item) => item.name === "nonlinearity").value;
  209. this._chain.push(new nnabla.Node(metadata, { name: func.name + "/act", type: get_nonlinearity(type_b) }));
  210. break;
  211. }
  212. default: {
  213. this._inputs = inputs || [];
  214. break;
  215. }
  216. }
  217. }
  218. get name() {
  219. return this._name;
  220. }
  221. get type() {
  222. return this._type;
  223. }
  224. get attributes() {
  225. return this._attributes;
  226. }
  227. get inputs() {
  228. return this._inputs;
  229. }
  230. get outputs() {
  231. return this._outputs;
  232. }
  233. get chain() {
  234. return this._chain;
  235. }
  236. };
  237. nnabla.Attribute = class {
  238. constructor(metadata, type, name, value) {
  239. this._name = name;
  240. const attribute = metadata.attribute(type, name);
  241. this._description = attribute.description;
  242. switch (attribute.type) {
  243. case "shape":
  244. this._type = "int64[]";
  245. this._value = value.dim;
  246. break;
  247. default:
  248. this._type = attribute.type;
  249. this._value = value;
  250. break;
  251. }
  252. if (Object.prototype.hasOwnProperty.call(attribute, 'default') && this._value == attribute.default) {
  253. this._visible = false;
  254. }
  255. }
  256. get name() {
  257. return this._name;
  258. }
  259. get description() {
  260. return this._description;
  261. }
  262. get type() {
  263. return this._type;
  264. }
  265. get value() {
  266. return this._value;
  267. }
  268. get visible() {
  269. return this._visible == false ? false : true;
  270. }
  271. };
  272. nnabla.Tensor = class {
  273. constructor(name, type, values) {
  274. this._name = name;
  275. this._type = type;
  276. this._values = values;
  277. }
  278. get name() {
  279. return this._name;
  280. }
  281. get type() {
  282. return this._type;
  283. }
  284. get layout() {
  285. return '|';
  286. }
  287. get values() {
  288. const dataType = this._type.dataType;
  289. switch (dataType) {
  290. case 'float32': return new Float32Array(this._values);
  291. default: throw new nnabla.Error("Unsupported data type '" + dataType + "'.");
  292. }
  293. }
  294. };
  295. nnabla.TensorType = class {
  296. constructor(dataType, shape) {
  297. this._dataType = "float32";
  298. this._shape = shape;
  299. this._denotation = null; // TODO
  300. }
  301. get dataType() {
  302. return this._dataType;
  303. }
  304. get shape() {
  305. return this._shape;
  306. }
  307. get denotation() {
  308. return this._denotation;
  309. }
  310. toString() {
  311. return this._dataType + this._shape.toString();
  312. }
  313. };
  314. nnabla.TensorShape = class {
  315. constructor(dimensions) {
  316. this._dimensions = dimensions;
  317. }
  318. get dimensions() {
  319. return this._dimensions;
  320. }
  321. toString() {
  322. return (this._dimensions && this._dimensions.length) ? ('[' + this._dimensions.join(',') + ']') : '';
  323. }
  324. };
  325. nnabla.Error = class extends Error {
  326. constructor(message) {
  327. super(message);
  328. this.name = 'Error loading Neural Network Library model.';
  329. }
  330. };
  331. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  332. module.exports.ModelFactory = nnabla.ModelFactory;
  333. }