bigdl.js 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. // Experimental
  2. const bigdl = {};
  3. bigdl.ModelFactory = class {
  4. match(context) {
  5. const tags = context.tags('pb');
  6. if (tags.has(2) && tags.has(7) && tags.has(8) &&
  7. tags.has(9) && tags.has(10) && tags.has(11) && tags.has(12)) {
  8. context.type = 'bigdl';
  9. }
  10. }
  11. async open(context) {
  12. bigdl.proto = await context.require('./bigdl-proto');
  13. bigdl.proto = bigdl.proto.com.intel.analytics.bigdl.serialization;
  14. let module = null;
  15. try {
  16. // https://github.com/intel-analytics/BigDL/blob/master/spark/dl/src/main/resources/serialization/bigdl.proto
  17. const reader = context.read('protobuf.binary');
  18. module = bigdl.proto.BigDLModule.decode(reader);
  19. } catch (error) {
  20. const message = error && error.message ? error.message : error.toString();
  21. throw new bigdl.Error(`File format is not bigdl.BigDLModule (${message.replace(/\.$/, '')}).`);
  22. }
  23. const metadata = await context.metadata('bigdl-metadata.json');
  24. return new bigdl.Model(metadata, module);
  25. }
  26. };
  27. bigdl.Model = class {
  28. constructor(metadata, module) {
  29. const version = module && module.version ? module.version : '';
  30. this.format = `BigDL${version ? ` v${version}` : ''}`;
  31. this.graphs = [ new bigdl.Graph(metadata, module) ];
  32. }
  33. };
  34. bigdl.Graph = class {
  35. constructor(metadata, module) {
  36. this.type = module.moduleType;
  37. this.inputs = [];
  38. this.outputs = [];
  39. this.nodes = [];
  40. const tensors = module.attr && module.attr.global_storage && module.attr.global_storage.nameAttrListValue && module.attr.global_storage.nameAttrListValue.attr ? module.attr.global_storage.nameAttrListValue.attr : {};
  41. const values = new Map();
  42. values.map = (name) => {
  43. if (!values.has(name)) {
  44. values.set(name, new bigdl.Value(name));
  45. }
  46. return values.get(name);
  47. };
  48. const loadModule = (metadata, module, tensors) => {
  49. switch (module.moduleType) {
  50. case 'com.intel.analytics.bigdl.nn.StaticGraph':
  51. case 'com.intel.analytics.bigdl.nn.Sequential': {
  52. for (const submodule of module.subModules) {
  53. loadModule(metadata, submodule, tensors);
  54. }
  55. break;
  56. }
  57. case 'com.intel.analytics.bigdl.nn.Input': {
  58. const argument = new bigdl.Argument(module.name, [ values.map(module.name) ]);
  59. this.inputs.push(argument);
  60. break;
  61. }
  62. default: {
  63. const node = new bigdl.Node(metadata, module, tensors, values);
  64. this.nodes.push(node);
  65. break;
  66. }
  67. }
  68. };
  69. loadModule(metadata, module, tensors);
  70. }
  71. };
  72. bigdl.Argument = class {
  73. constructor(name, value) {
  74. this.name = name;
  75. this.value = value;
  76. }
  77. };
  78. bigdl.Value = class {
  79. constructor(name, type, initializer) {
  80. if (typeof name !== 'string') {
  81. throw new bigdl.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  82. }
  83. this.name = name;
  84. this.type = type ? type : initializer ? initializer.type : null;
  85. this.initializer = initializer;
  86. }
  87. };
  88. bigdl.Node = class {
  89. constructor(metadata, module, tensors, values) {
  90. const type = module.moduleType;
  91. this.name = module.name;
  92. this.attributes = [];
  93. this.inputs = [];
  94. this.outputs = [];
  95. this.inputs.push(new bigdl.Argument('input', module.preModules.map((id) => values.map(id))));
  96. this.type = metadata.type(type) || { name: type };
  97. const inputs = this.type && this.type.inputs ? this.type.inputs.slice() : [];
  98. inputs.shift();
  99. if (module.weight) {
  100. inputs.shift();
  101. this.inputs.push(new bigdl.Argument('weight', [
  102. new bigdl.Value('', null, new bigdl.Tensor(module.weight, tensors))
  103. ]));
  104. }
  105. if (module.bias) {
  106. inputs.shift();
  107. this.inputs.push(new bigdl.Argument('bias', [
  108. new bigdl.Value('', null, new bigdl.Tensor(module.bias, tensors))
  109. ]));
  110. }
  111. if (module.parameters && module.parameters.length > 0) {
  112. for (const parameter of module.parameters) {
  113. const input = inputs.shift();
  114. const inputName = input ? input.name : this.inputs.length.toString();
  115. this.inputs.push(new bigdl.Argument(inputName, [
  116. new bigdl.Value('', null, new bigdl.Tensor(parameter, tensors))
  117. ]));
  118. }
  119. }
  120. for (const [key, value] of Object.entries(module.attr)) {
  121. if (key === 'module_numerics' || key === 'module_tags') {
  122. continue;
  123. }
  124. if (value.dataType === bigdl.proto.DataType.TENSOR) {
  125. if (value.value) {
  126. this.inputs.push(new bigdl.Argument(key, [ new bigdl.Value('', null, new bigdl.Tensor(value.tensorValue, tensors)) ]));
  127. }
  128. continue;
  129. }
  130. if (value.dataType === bigdl.proto.DataType.REGULARIZER && value.value === undefined) {
  131. continue;
  132. }
  133. if (value.dataType === bigdl.proto.DataType.ARRAY_VALUE && value.arrayValue.datatype === bigdl.proto.DataType.TENSOR) {
  134. this.inputs.push(new bigdl.Argument(key, value.arrayValue.tensor.map((tensor) => new bigdl.Value('', null, new bigdl.Tensor(tensor, tensors)))));
  135. continue;
  136. }
  137. this.attributes.push(new bigdl.Attribute(key, value));
  138. }
  139. const output = this.name || this.type + module.namePostfix;
  140. this.outputs.push(new bigdl.Argument('output', [ values.map(output) ]));
  141. }
  142. };
  143. bigdl.Attribute = class {
  144. constructor(name, value) {
  145. this.name = name;
  146. switch (value.dataType) {
  147. case bigdl.proto.DataType.INT32: {
  148. this.type = 'int32';
  149. this.value = value.int32Value;
  150. break;
  151. }
  152. case bigdl.proto.DataType.FLOAT: {
  153. this.type = 'float32';
  154. this.value = value.floatValue;
  155. break;
  156. }
  157. case bigdl.proto.DataType.DOUBLE: {
  158. this.type = 'float64';
  159. this.value = value.doubleValue;
  160. break;
  161. }
  162. case bigdl.proto.DataType.BOOL: {
  163. this.type = 'boolean';
  164. this.value = value.boolValue;
  165. break;
  166. }
  167. case bigdl.proto.DataType.REGULARIZER: {
  168. this.value = value.value;
  169. break;
  170. }
  171. case bigdl.proto.DataType.MODULE: {
  172. this.value = value.bigDLModule;
  173. break;
  174. }
  175. case bigdl.proto.DataType.NAME_ATTR_LIST: {
  176. this.value = value.nameAttrListValue;
  177. break;
  178. }
  179. case bigdl.proto.DataType.ARRAY_VALUE: {
  180. switch (value.arrayValue.datatype) {
  181. case bigdl.proto.DataType.INT32: {
  182. this.type = 'int32[]';
  183. this.value = value.arrayValue.i32;
  184. break;
  185. }
  186. case bigdl.proto.DataType.FLOAT: {
  187. this.type = 'float32[]';
  188. this.value = value.arrayValue.flt;
  189. break;
  190. }
  191. case bigdl.proto.DataType.STRING: {
  192. this.type = 'string[]';
  193. this.value = value.arrayValue.str;
  194. break;
  195. }
  196. case bigdl.proto.DataType.TENSOR: {
  197. this.type = 'tensor[]';
  198. this.value = value.arrayValue.tensor;
  199. break;
  200. }
  201. default: {
  202. throw new bigdl.Error(`Unsupported attribute array data type '${value.arrayValue.datatype}'.`);
  203. }
  204. }
  205. break;
  206. }
  207. case bigdl.proto.DataType.DATA_FORMAT: {
  208. switch (value.dataFormatValue) {
  209. case 0: this.value = 'NCHW'; break;
  210. case 1: this.value = 'NHWC'; break;
  211. default: throw new bigdl.Error(`Unsupported data format '${value.dataFormatValue}'.`);
  212. }
  213. break;
  214. }
  215. default: {
  216. throw new bigdl.Error(`Unsupported attribute data type '${value.dataType}'.`);
  217. }
  218. }
  219. }
  220. };
  221. bigdl.Tensor = class {
  222. constructor(tensor /*, tensors */) {
  223. this.type = new bigdl.TensorType(tensor.datatype, new bigdl.TensorShape(tensor.size));
  224. /*
  225. if (tensor && tensor.id && tensors && tensors[tensor.id] && tensors[tensor.id].tensorValue && tensors[tensor.id].tensorValue.storage) {
  226. const storage = tensors[tensor.id].tensorValue.storage;
  227. switch (this.type.dataType) {
  228. case 'float32':
  229. if (storage.bytes_data && storage.bytes_data.length > 0) {
  230. this.values = storage.bytes_data[0];
  231. this.encoding = '<';
  232. }
  233. else if (storage.float_data && storage.float_data.length > 0) {
  234. this.values = storage.float_data;
  235. this.encoding = '|';
  236. }
  237. break;
  238. default:
  239. break;
  240. }
  241. }
  242. */
  243. }
  244. };
  245. bigdl.TensorType = class {
  246. constructor(dataType, shape) {
  247. switch (dataType) {
  248. case bigdl.proto.DataType.FLOAT: this.dataType = 'float32'; break;
  249. case bigdl.proto.DataType.DOUBLE: this.dataType = 'float64'; break;
  250. default: throw new bigdl.Error(`Unsupported tensor type '${dataType}'.`);
  251. }
  252. this.shape = shape;
  253. }
  254. toString() {
  255. return (this.dataType || '?') + this.shape.toString();
  256. }
  257. };
  258. bigdl.TensorShape = class {
  259. constructor(dimensions) {
  260. this.dimensions = dimensions;
  261. if (!dimensions.every((dimension) => Number.isInteger(dimension))) {
  262. throw new bigdl.Error(`Invalid tensor shape '${JSON.stringify(dimensions)}'.`);
  263. }
  264. }
  265. toString() {
  266. return this.dimensions ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  267. }
  268. };
  269. bigdl.Error = class extends Error {
  270. constructor(message) {
  271. super(message);
  272. this.name = 'Error loading BigDL model.';
  273. }
  274. };
  275. export const ModelFactory = bigdl.ModelFactory;