// bigdl.js (11 KB)
  1. // Experimental
  2. const bigdl = {};
  3. bigdl.ModelFactory = class {
  4. async match(context) {
  5. const tags = await context.tags('pb');
  6. if (tags.has(2) && tags.has(7) && tags.has(8) &&
  7. tags.has(9) && tags.has(10) && tags.has(11) && tags.has(12)) {
  8. return context.set('bigdl');
  9. }
  10. return null;
  11. }
  12. async open(context) {
  13. bigdl.proto = await context.require('./bigdl-proto');
  14. bigdl.proto = bigdl.proto.com.intel.analytics.bigdl.serialization;
  15. let module = null;
  16. try {
  17. // https://github.com/intel-analytics/BigDL/blob/master/spark/dl/src/main/resources/serialization/bigdl.proto
  18. const reader = await context.read('protobuf.binary');
  19. module = bigdl.proto.BigDLModule.decode(reader);
  20. } catch (error) {
  21. const message = error && error.message ? error.message : error.toString();
  22. throw new bigdl.Error(`File format is not bigdl.BigDLModule (${message.replace(/\.$/, '')}).`);
  23. }
  24. const metadata = await context.metadata('bigdl-metadata.json');
  25. return new bigdl.Model(metadata, module);
  26. }
  27. };
  28. bigdl.Model = class {
  29. constructor(metadata, module) {
  30. const version = module && module.version ? module.version : '';
  31. this.format = `BigDL${version ? ` v${version}` : ''}`;
  32. this.modules = [new bigdl.Graph(metadata, module)];
  33. }
  34. };
  35. bigdl.Graph = class {
  36. constructor(metadata, module) {
  37. this.inputs = [];
  38. this.outputs = [];
  39. this.nodes = [];
  40. this.description = module.moduleType;
  41. const tensors = module.attr && module.attr.global_storage && module.attr.global_storage.nameAttrListValue && module.attr.global_storage.nameAttrListValue.attr ? module.attr.global_storage.nameAttrListValue.attr : {};
  42. const values = new Map();
  43. values.map = (name) => {
  44. if (!values.has(name)) {
  45. values.set(name, new bigdl.Value(name));
  46. }
  47. return values.get(name);
  48. };
  49. const loadModule = (metadata, module, tensors) => {
  50. switch (module.moduleType) {
  51. case 'com.intel.analytics.bigdl.nn.StaticGraph':
  52. case 'com.intel.analytics.bigdl.nn.Sequential': {
  53. for (const submodule of module.subModules) {
  54. loadModule(metadata, submodule, tensors);
  55. }
  56. break;
  57. }
  58. case 'com.intel.analytics.bigdl.nn.Input': {
  59. const argument = new bigdl.Argument(module.name, [values.map(module.name)]);
  60. this.inputs.push(argument);
  61. break;
  62. }
  63. default: {
  64. const node = new bigdl.Node(metadata, module, tensors, values);
  65. this.nodes.push(node);
  66. break;
  67. }
  68. }
  69. };
  70. loadModule(metadata, module, tensors);
  71. }
  72. };
  73. bigdl.Argument = class {
  74. constructor(name, value, type = null) {
  75. this.name = name;
  76. this.value = value;
  77. this.type = type;
  78. }
  79. };
  80. bigdl.Value = class {
  81. constructor(name, type, initializer) {
  82. if (typeof name !== 'string') {
  83. throw new bigdl.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  84. }
  85. this.name = name;
  86. this.type = !type && initializer ? initializer.type : type;
  87. this.initializer = initializer;
  88. }
  89. };
  90. bigdl.Node = class {
  91. constructor(metadata, module, tensors, values) {
  92. const type = module.moduleType;
  93. this.name = module.name;
  94. this.attributes = [];
  95. this.inputs = [];
  96. this.outputs = [];
  97. this.inputs.push(new bigdl.Argument('input', module.preModules.map((id) => values.map(id))));
  98. this.type = metadata.type(type) || { name: type };
  99. const inputs = this.type && this.type.inputs ? this.type.inputs.slice() : [];
  100. inputs.shift();
  101. if (module.weight) {
  102. inputs.shift();
  103. this.inputs.push(new bigdl.Argument('weight', [
  104. new bigdl.Value('', null, new bigdl.Tensor(module.weight, tensors))
  105. ]));
  106. }
  107. if (module.bias) {
  108. inputs.shift();
  109. this.inputs.push(new bigdl.Argument('bias', [
  110. new bigdl.Value('', null, new bigdl.Tensor(module.bias, tensors))
  111. ]));
  112. }
  113. if (module.parameters && module.parameters.length > 0) {
  114. for (const parameter of module.parameters) {
  115. const input = inputs.shift();
  116. const inputName = input ? input.name : this.inputs.length.toString();
  117. this.inputs.push(new bigdl.Argument(inputName, [
  118. new bigdl.Value('', null, new bigdl.Tensor(parameter, tensors))
  119. ]));
  120. }
  121. }
  122. for (const [key, obj] of Object.entries(module.attr)) {
  123. if (key === 'module_numerics' || key === 'module_tags') {
  124. continue;
  125. }
  126. if (obj.dataType === bigdl.proto.DataType.TENSOR) {
  127. if (obj.value) {
  128. this.inputs.push(new bigdl.Argument(key, [new bigdl.Value('', null, new bigdl.Tensor(obj.tensorValue, tensors))]));
  129. }
  130. continue;
  131. }
  132. if (obj.dataType === bigdl.proto.DataType.REGULARIZER && obj.value === undefined) {
  133. continue;
  134. }
  135. if (obj.dataType === bigdl.proto.DataType.ARRAY_VALUE && obj.arrayValue.datatype === bigdl.proto.DataType.TENSOR) {
  136. this.inputs.push(new bigdl.Argument(key, obj.arrayValue.tensor.map((tensor) => new bigdl.Value('', null, new bigdl.Tensor(tensor, tensors)))));
  137. continue;
  138. }
  139. let type = null;
  140. let value = null;
  141. switch (obj.dataType) {
  142. case bigdl.proto.DataType.INT32: {
  143. type = 'int32';
  144. value = obj.int32Value;
  145. break;
  146. }
  147. case bigdl.proto.DataType.FLOAT: {
  148. type = 'float32';
  149. value = obj.floatValue;
  150. break;
  151. }
  152. case bigdl.proto.DataType.DOUBLE: {
  153. type = 'float64';
  154. value = obj.doubleValue;
  155. break;
  156. }
  157. case bigdl.proto.DataType.BOOL: {
  158. type = 'boolean';
  159. value = obj.boolValue;
  160. break;
  161. }
  162. case bigdl.proto.DataType.REGULARIZER: {
  163. value = obj.value;
  164. break;
  165. }
  166. case bigdl.proto.DataType.MODULE: {
  167. value = obj.bigDLModule;
  168. break;
  169. }
  170. case bigdl.proto.DataType.NAME_ATTR_LIST: {
  171. value = value.nameAttrListValue;
  172. break;
  173. }
  174. case bigdl.proto.DataType.ARRAY_VALUE: {
  175. switch (obj.arrayValue.datatype) {
  176. case bigdl.proto.DataType.INT32: {
  177. type = 'int32[]';
  178. value = obj.arrayValue.i32;
  179. break;
  180. }
  181. case bigdl.proto.DataType.FLOAT: {
  182. type = 'float32[]';
  183. value = obj.arrayValue.flt;
  184. break;
  185. }
  186. case bigdl.proto.DataType.STRING: {
  187. type = 'string[]';
  188. value = obj.arrayValue.str;
  189. break;
  190. }
  191. case bigdl.proto.DataType.TENSOR: {
  192. type = 'tensor[]';
  193. value = obj.arrayValue.tensor;
  194. break;
  195. }
  196. default: {
  197. throw new bigdl.Error(`Unsupported attribute array data type '${obj.arrayValue.datatype}'.`);
  198. }
  199. }
  200. break;
  201. }
  202. case bigdl.proto.DataType.DATA_FORMAT: {
  203. switch (obj.dataFormatValue) {
  204. case 0: value = 'NCHW'; break;
  205. case 1: value = 'NHWC'; break;
  206. default: throw new bigdl.Error(`Unsupported data format '${obj.dataFormatValue}'.`);
  207. }
  208. break;
  209. }
  210. default: {
  211. throw new bigdl.Error(`Unsupported attribute data type '${obj.dataType}'.`);
  212. }
  213. }
  214. const argument = new bigdl.Argument(key, value, type);
  215. this.attributes.push(argument);
  216. }
  217. const output = this.name || this.type + module.namePostfix;
  218. this.outputs.push(new bigdl.Argument('output', [values.map(output)]));
  219. }
  220. };
  221. bigdl.Tensor = class {
  222. constructor(tensor /*, tensors */) {
  223. this.type = new bigdl.TensorType(tensor.datatype, new bigdl.TensorShape(tensor.size));
  224. /*
  225. if (tensor && tensor.id && tensors && tensors[tensor.id] && tensors[tensor.id].tensorValue && tensors[tensor.id].tensorValue.storage) {
  226. const storage = tensors[tensor.id].tensorValue.storage;
  227. switch (this.type.dataType) {
  228. case 'float32':
  229. if (storage.bytes_data && storage.bytes_data.length > 0) {
  230. this.values = storage.bytes_data[0];
  231. this.encoding = '<';
  232. }
  233. else if (storage.float_data && storage.float_data.length > 0) {
  234. this.values = storage.float_data;
  235. this.encoding = '|';
  236. }
  237. break;
  238. default:
  239. break;
  240. }
  241. }
  242. */
  243. }
  244. };
  245. bigdl.TensorType = class {
  246. constructor(dataType, shape) {
  247. switch (dataType) {
  248. case bigdl.proto.DataType.FLOAT: this.dataType = 'float32'; break;
  249. case bigdl.proto.DataType.DOUBLE: this.dataType = 'float64'; break;
  250. default: throw new bigdl.Error(`Unsupported tensor type '${dataType}'.`);
  251. }
  252. this.shape = shape;
  253. }
  254. toString() {
  255. return (this.dataType || '?') + this.shape.toString();
  256. }
  257. };
  258. bigdl.TensorShape = class {
  259. constructor(dimensions) {
  260. this.dimensions = dimensions;
  261. if (!dimensions.every((dimension) => Number.isInteger(dimension))) {
  262. throw new bigdl.Error(`Invalid tensor shape '${JSON.stringify(dimensions)}'.`);
  263. }
  264. }
  265. toString() {
  266. return this.dimensions ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  267. }
  268. };
  269. bigdl.Error = class extends Error {
  270. constructor(message) {
  271. super(message);
  272. this.name = 'Error loading BigDL model.';
  273. }
  274. };
  275. export const ModelFactory = bigdl.ModelFactory;