bigdl.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. // Experimental
  2. var bigdl = bigdl || {};
  3. var protobuf = protobuf || require('./protobuf');
  4. bigdl.ModelFactory = class {
  5. match(context) {
  6. const tags = context.tags('pb');
  7. if (tags.has(2) && tags.has(7) && tags.has(8) && tags.has(9) && tags.has(10) && tags.has(11) && tags.has(12)) {
  8. return 'bigdl';
  9. }
  10. return '';
  11. }
  12. open(context) {
  13. return context.require('./bigdl-proto').then(() => {
  14. let module = null;
  15. try {
  16. // https://github.com/intel-analytics/BigDL/blob/master/spark/dl/src/main/resources/serialization/bigdl.proto
  17. bigdl.proto = protobuf.get('bigdl').com.intel.analytics.bigdl.serialization;
  18. const stream = context.stream;
  19. const reader = protobuf.BinaryReader.open(stream);
  20. module = bigdl.proto.BigDLModule.decode(reader);
  21. }
  22. catch (error) {
  23. const message = error && error.message ? error.message : error.toString();
  24. throw new bigdl.Error('File format is not bigdl.BigDLModule (' + message.replace(/\.$/, '') + ').');
  25. }
  26. return context.metadata('bigdl-metadata.json').then((metadata) => {
  27. return new bigdl.Model(metadata, module);
  28. });
  29. });
  30. }
  31. };
  32. bigdl.Model = class {
  33. constructor(metadata, module) {
  34. this._version = module && module.version ? module.version : '';
  35. this._graphs = [ new bigdl.Graph(metadata, module) ];
  36. }
  37. get format() {
  38. return 'BigDL' + (this._version ? ' v' + this._version : '');
  39. }
  40. get graphs() {
  41. return this._graphs;
  42. }
  43. };
  44. bigdl.Graph = class {
  45. constructor(metadata, module) {
  46. this._type = module.moduleType;
  47. this._inputs = [];
  48. this._outputs = [];
  49. this._nodes = [];
  50. this._loadModule(metadata, module);
  51. }
  52. _loadModule(metadata, module) {
  53. switch (module.moduleType) {
  54. case 'com.intel.analytics.bigdl.nn.StaticGraph':
  55. case 'com.intel.analytics.bigdl.nn.Sequential': {
  56. for (const submodule of module.subModules) {
  57. this._loadModule(metadata, submodule);
  58. }
  59. break;
  60. }
  61. case 'com.intel.analytics.bigdl.nn.Input': {
  62. this._inputs.push(new bigdl.Parameter(module.name, [
  63. new bigdl.Argument(module.name)
  64. ]));
  65. break;
  66. }
  67. default: {
  68. this._nodes.push(new bigdl.Node(metadata, module));
  69. break;
  70. }
  71. }
  72. }
  73. get type() {
  74. return this._type;
  75. }
  76. get inputs() {
  77. return this._inputs;
  78. }
  79. get outputs() {
  80. return this._outputs;
  81. }
  82. get nodes() {
  83. return this._nodes;
  84. }
  85. };
  86. bigdl.Parameter = class {
  87. constructor(name, args) {
  88. this._name = name;
  89. this._arguments = args;
  90. }
  91. get name() {
  92. return this._name;
  93. }
  94. get visible() {
  95. return true;
  96. }
  97. get arguments() {
  98. return this._arguments;
  99. }
  100. };
  101. bigdl.Argument = class {
  102. constructor(name, type, initializer) {
  103. if (typeof name !== 'string') {
  104. throw new bigdl.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  105. }
  106. this._name = name;
  107. this._type = type || null;
  108. this._initializer = initializer || null;
  109. }
  110. get name() {
  111. return this._name;
  112. }
  113. get type() {
  114. if (this._initializer) {
  115. return this._initializer.type;
  116. }
  117. return this._type;
  118. }
  119. get initializer() {
  120. return this._initializer;
  121. }
  122. };
  123. bigdl.Node = class {
  124. constructor(metadata, module) {
  125. const type = module.moduleType;
  126. this._name = module.name;
  127. this._attributes = [];
  128. this._inputs = [];
  129. this._outputs = [];
  130. this._inputs.push(new bigdl.Parameter('input', module.preModules.map((id) => new bigdl.Argument(id, null, null))));
  131. this._type = metadata.type(type) || { name: type };
  132. const inputs = (this._type && this._type.inputs) ? this._type.inputs.slice() : [];
  133. inputs.shift();
  134. if (module.weight) {
  135. inputs.shift();
  136. this._inputs.push(new bigdl.Parameter('weight', [
  137. new bigdl.Argument('', null, new bigdl.Tensor(module.weight))
  138. ]));
  139. }
  140. if (module.bias) {
  141. inputs.shift();
  142. this._inputs.push(new bigdl.Parameter('bias', [
  143. new bigdl.Argument('', null, new bigdl.Tensor(module.bias))
  144. ]));
  145. }
  146. if (module.parameters && module.parameters.length > 0) {
  147. for (const parameter of module.parameters) {
  148. const input = inputs.shift();
  149. const inputName = input ? input.name : this._inputs.length.toString();
  150. this._inputs.push(new bigdl.Parameter(inputName, [
  151. new bigdl.Argument('', null, new bigdl.Tensor(parameter))
  152. ]));
  153. }
  154. }
  155. for (const key of Object.keys(module.attr)) {
  156. const value = module.attr[key];
  157. if (key === 'module_numerics' || key === 'module_tags') {
  158. continue;
  159. }
  160. if (value.dataType === bigdl.proto.DataType.TENSOR) {
  161. if (value.value) {
  162. this._inputs.push(new bigdl.Parameter(key, [ new bigdl.Argument('', null, new bigdl.Tensor(value.tensorValue)) ]));
  163. }
  164. continue;
  165. }
  166. if (value.dataType === bigdl.proto.DataType.REGULARIZER && value.value === undefined) {
  167. continue;
  168. }
  169. if (value.dataType === bigdl.proto.DataType.ARRAY_VALUE && value.arrayValue.datatype === bigdl.proto.DataType.TENSOR) {
  170. this._inputs.push(new bigdl.Parameter(key, value.arrayValue.tensor.map((tensor) => new bigdl.Argument('', null, new bigdl.Tensor(tensor)))));
  171. continue;
  172. }
  173. this._attributes.push(new bigdl.Attribute(key, value));
  174. }
  175. const output = this._name || this._type + module.namePostfix;
  176. this._outputs.push(new bigdl.Parameter('output', [
  177. new bigdl.Argument(output, null, null)
  178. ]));
  179. }
  180. get type() {
  181. return this._type;
  182. }
  183. get name() {
  184. return this._name;
  185. }
  186. get inputs() {
  187. return this._inputs;
  188. }
  189. get outputs() {
  190. return this._outputs;
  191. }
  192. get attributes() {
  193. return this._attributes;
  194. }
  195. };
  196. bigdl.Attribute = class {
  197. constructor(name, value) {
  198. this._name = name;
  199. switch (value.dataType) {
  200. case bigdl.proto.DataType.INT32: {
  201. this._type = 'int32';
  202. this._value = value.int32Value;
  203. break;
  204. }
  205. case bigdl.proto.DataType.FLOAT: {
  206. this._type = 'float32';
  207. this._value = value.floatValue;
  208. break;
  209. }
  210. case bigdl.proto.DataType.DOUBLE: {
  211. this._type = 'float64';
  212. this._value = value.doubleValue;
  213. break;
  214. }
  215. case bigdl.proto.DataType.BOOL: {
  216. this._type = 'boolean';
  217. this._value = value.boolValue;
  218. break;
  219. }
  220. case bigdl.proto.DataType.REGULARIZER: {
  221. this._value = value.value;
  222. break;
  223. }
  224. case bigdl.proto.DataType.MODULE: {
  225. this._value = value.bigDLModule;
  226. break;
  227. }
  228. case bigdl.proto.DataType.NAME_ATTR_LIST: {
  229. this._value = value.nameAttrListValue;
  230. break;
  231. }
  232. case bigdl.proto.DataType.ARRAY_VALUE: {
  233. switch (value.arrayValue.datatype) {
  234. case bigdl.proto.DataType.INT32: {
  235. this._type = 'int32[]';
  236. this._value = value.arrayValue.i32;
  237. break;
  238. }
  239. case bigdl.proto.DataType.FLOAT: {
  240. this._type = 'float32[]';
  241. this._value = value.arrayValue.flt;
  242. break;
  243. }
  244. case bigdl.proto.DataType.STRING: {
  245. this._type = 'string[]';
  246. this._value = value.arrayValue.str;
  247. break;
  248. }
  249. case bigdl.proto.DataType.TENSOR: {
  250. this._type = 'tensor[]';
  251. this._value = value.arrayValue.tensor;
  252. break;
  253. }
  254. default: {
  255. throw new bigdl.Error("Unsupported attribute array data type '" + value.arrayValue.datatype + "'.");
  256. }
  257. }
  258. break;
  259. }
  260. case bigdl.proto.DataType.DATA_FORMAT: {
  261. this._dataType = 'InputDataFormat';
  262. switch (value.dataFormatValue) {
  263. case 0: this._value = 'NCHW'; break;
  264. case 1: this._value = 'NHWC'; break;
  265. default: throw new bigdl.Error("Unsupported data format '" + value.dataFormatValue + "'.");
  266. }
  267. break;
  268. }
  269. default: {
  270. throw new bigdl.Error("Unsupported attribute data type '" + value.dataType + "'.");
  271. }
  272. }
  273. }
  274. get type() {
  275. return this._type;
  276. }
  277. get name() {
  278. return this._name;
  279. }
  280. get value() {
  281. return this._value;
  282. }
  283. get visible() {
  284. return true;
  285. }
  286. };
  287. bigdl.Tensor = class {
  288. constructor(tensor) {
  289. this._type = new bigdl.TensorType(tensor.datatype, new bigdl.TensorShape(tensor.size));
  290. }
  291. get kind() {
  292. return 'Parameter';
  293. }
  294. get type() {
  295. return this._type;
  296. }
  297. get state() {
  298. return 'Tensor data not implemented.';
  299. }
  300. get value() {
  301. return null;
  302. }
  303. toString() {
  304. return '';
  305. }
  306. };
  307. bigdl.TensorType = class {
  308. constructor(dataType, shape) {
  309. switch (dataType) {
  310. case bigdl.proto.DataType.FLOAT: this._dataType = 'float32'; break;
  311. case bigdl.proto.DataType.DOUBLE: this._dataType = 'float64'; break;
  312. default: throw new bigdl.Error("Unsupported tensor type '" + dataType + "'.");
  313. }
  314. this._shape = shape;
  315. }
  316. get dataType() {
  317. return this._dataType;
  318. }
  319. get shape() {
  320. return this._shape;
  321. }
  322. toString() {
  323. return (this.dataType || '?') + this._shape.toString();
  324. }
  325. };
  326. bigdl.TensorShape = class {
  327. constructor(dimensions) {
  328. this._dimensions = dimensions;
  329. if (!dimensions.every((dimension) => Number.isInteger(dimension))) {
  330. throw new bigdl.Error("Invalid tensor shape '" + JSON.stringify(dimensions) + "'.");
  331. }
  332. }
  333. get dimensions() {
  334. return this._dimensions;
  335. }
  336. toString() {
  337. return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : '';
  338. }
  339. };
  340. bigdl.Error = class extends Error {
  341. constructor(message) {
  342. super(message);
  343. this.name = 'Error loading BigDL model.';
  344. }
  345. };
  346. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  347. module.exports.ModelFactory = bigdl.ModelFactory;
  348. }