cntk.js 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073
  1. import * as base from './base.js';
  2. const cntk = {};
  3. cntk.ModelFactory = class {
  4. match(context) {
  5. const stream = context.stream;
  6. // CNTK v1
  7. const signature = [ 0x42, 0x00, 0x43, 0x00, 0x4e, 0x00, 0x00, 0x00 ];
  8. if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
  9. context.type = 'cntk.v1';
  10. return;
  11. }
  12. // CNTK v2
  13. const tags = context.tags('pb');
  14. if (tags.get(1) === 0 && tags.get(2) === 2) {
  15. context.type = 'cntk.v2';
  16. return;
  17. }
  18. }
  19. async open(context) {
  20. const metadata = await context.metadata('cntk-metadata.json');
  21. switch (context.type) {
  22. case 'cntk.v1': {
  23. let obj = null;
  24. try {
  25. const stream = context.stream;
  26. const buffer = stream.peek();
  27. obj = new cntk.ComputationNetwork(buffer);
  28. } catch (error) {
  29. const message = error && error.message ? error.message : error.toString();
  30. throw new cntk.Error(`File format is not CNTK v1 (${message.replace(/\.$/, '')}).`);
  31. }
  32. return new cntk.Model(metadata, 1, obj);
  33. }
  34. case 'cntk.v2': {
  35. cntk.proto = await context.require('./cntk-proto');
  36. cntk.proto = cntk.proto.CNTK.proto;
  37. cntk.proto.PoolingType = { 0: 'Max', 1: 'Average' };
  38. let obj = null;
  39. try {
  40. const reader = context.read('protobuf.binary');
  41. const dictionary = cntk.proto.Dictionary.decode(reader);
  42. obj = cntk.ModelFactory._convertDictionary(dictionary);
  43. } catch (error) {
  44. const message = error && error.message ? error.message : error.toString();
  45. throw new cntk.Error(`File format is not cntk.Dictionary (${message.replace(/\.$/, '')}).`);
  46. }
  47. return new cntk.Model(metadata, 2, obj);
  48. }
  49. default: {
  50. throw new cntk.Error(`Unsupported CNTK format '${context.type}'.`);
  51. }
  52. }
  53. }
  54. static _convertDictionary(dictionary) {
  55. const target = {};
  56. for (const key of Object.keys(dictionary.data).filter((key) => key != 'version')) {
  57. target[key] = cntk.ModelFactory._convertDictionaryValue(dictionary.data[key]);
  58. }
  59. return target;
  60. }
  61. static _convertDictionaryValue(dictionaryValue) {
  62. switch (dictionaryValue.value_type) {
  63. case cntk.proto.DictionaryValue.Type.Bool:
  64. return dictionaryValue.bool_value;
  65. case cntk.proto.DictionaryValue.Type.Int:
  66. return dictionaryValue.int_value;
  67. case cntk.proto.DictionaryValue.Type.SizeT:
  68. return dictionaryValue.size_t_value;
  69. case cntk.proto.DictionaryValue.Type.Float:
  70. return dictionaryValue.float_value;
  71. case cntk.proto.DictionaryValue.Type.Double:
  72. return dictionaryValue.double_value;
  73. case cntk.proto.DictionaryValue.Type.String:
  74. return dictionaryValue.string_value;
  75. case cntk.proto.DictionaryValue.Type.Vector:
  76. return cntk.ModelFactory._convertVectorValue(dictionaryValue.vector_value);
  77. case cntk.proto.DictionaryValue.Type.NDShape:
  78. return dictionaryValue.nd_shape_value;
  79. case cntk.proto.DictionaryValue.Type.Axis:
  80. return dictionaryValue.axis_value;
  81. case cntk.proto.DictionaryValue.Type.Dictionary:
  82. return cntk.ModelFactory._convertDictionary(dictionaryValue.dictionary_value);
  83. case cntk.proto.DictionaryValue.Type.NDArrayView:
  84. return dictionaryValue.nd_array_view_value;
  85. default:
  86. throw new cntk.Error(`Unsupported dictionary value type '${dictionaryValue.value_type}'.`);
  87. }
  88. }
  89. static _convertVectorValue(vectorValue) {
  90. return vectorValue.value.map((item) => {
  91. return cntk.ModelFactory._convertDictionaryValue(item);
  92. });
  93. }
  94. };
  95. cntk.Model = class {
  96. constructor(metadata, version, obj) {
  97. switch (version) {
  98. case 1:
  99. this.format = `CNTK v1${obj.version ? (`.${obj.version}`) : ''}`;
  100. break;
  101. case 2:
  102. this.format = 'CNTK v2';
  103. break;
  104. default:
  105. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  106. }
  107. this.graphs = [ new cntk.Graph(metadata, version, obj) ];
  108. }
  109. };
  110. cntk.Graph = class {
  111. constructor(metadata, version, obj) {
  112. metadata = new cntk.GraphMetadata(metadata);
  113. this.inputs = [];
  114. this.outputs = [];
  115. this.nodes = [];
  116. const values = new Map();
  117. values.map = (name, version, obj) => {
  118. if (obj && values.has(name)) {
  119. throw new cntk.Error(`Duplicate value '${name}'.`);
  120. }
  121. if (!values.has(name)) {
  122. switch (version) {
  123. case 1:
  124. values.set(name, new cntk.Value(version, obj ? obj : { name: name }));
  125. break;
  126. case 2:
  127. values.set(name, new cntk.Value(version, obj ? obj : { uid: name }));
  128. break;
  129. default:
  130. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  131. }
  132. }
  133. return values.get(name);
  134. };
  135. switch (version) {
  136. case 1: {
  137. for (const name of Object.keys(obj.nodes)) {
  138. const node = obj.nodes[name];
  139. switch (node.__type__) {
  140. case 'InputValue': {
  141. const argument = new cntk.Argument(node.name, [ values.map(node.name, version, node) ]);
  142. this.inputs.push(argument);
  143. break;
  144. }
  145. case 'LearnableParameter': {
  146. values.map(node.name, version, node);
  147. break;
  148. }
  149. default:
  150. break;
  151. }
  152. }
  153. for (const name of Object.keys(obj.nodes)) {
  154. const node = obj.nodes[name];
  155. if (node.__type__ != 'InputValue' && node.__type__ != 'LearnableParameter') {
  156. this.nodes.push(new cntk.Node(metadata, version, node, values));
  157. }
  158. }
  159. if (obj.output) {
  160. for (const output of obj.output) {
  161. const argument = new cntk.Argument(output, [ values.map(output, version) ]);
  162. this.outputs.push(argument);
  163. }
  164. }
  165. break;
  166. }
  167. case 2: {
  168. const map = new Map(obj.primitive_functions.map((node) => [ node.uid, node ]));
  169. for (const input of obj.inputs) {
  170. const value = values.map(input.uid, version, input);
  171. // VariableKind { 0: 'input', 1: 'output', 2: 'parameter', 3: 'constant', 4: 'placeholder' }
  172. if (input.kind == 0) {
  173. const inputName = input.name || input.uid;
  174. this.inputs.push(new cntk.Argument(inputName, [ value ]));
  175. }
  176. }
  177. for (const block of obj.primitive_functions) {
  178. if (block.op == 57 && block.block_function_composite) {
  179. const list = [ block.block_function_composite.root ];
  180. const output = map.get(block.block_function_composite.root);
  181. const keys = block.block_function_composite_arguments_map_keys;
  182. const args = block.block_function_composite_arguments_map_values;
  183. block.inputs = args;
  184. if (!Array.isArray(keys) || !Array.isArray(args) || keys.length !== args.length) {
  185. throw new cntk.Error('Invalid block function composite arguments.');
  186. }
  187. const inputs = keys.map((key) => new cntk.Argument(key, [ values.map(key, version) ]));
  188. const outputs = [ new cntk.Argument('output', [ values.map(`${output.uid}_Output_0`, version) ]) ];
  189. const nodes = [];
  190. while (list.length > 0) {
  191. const name = list.shift();
  192. if (map.has(name)) {
  193. const node = map.get(name);
  194. nodes.push(new cntk.Node(metadata, version, node, values));
  195. map.delete(name);
  196. for (let i = 0; i < node.inputs.length; i++) {
  197. const parts = node.inputs[i].split('_');
  198. if (parts.length >= 3) {
  199. parts.pop();
  200. if (parts.pop() == 'Output') {
  201. list.push(parts.join('_'));
  202. }
  203. }
  204. }
  205. }
  206. }
  207. const func = new cntk.Function(block.block_function_op_name, nodes, inputs, outputs);
  208. metadata.add(block.uid, func);
  209. }
  210. }
  211. for (const node of map.values()) {
  212. this.nodes.push(new cntk.Node(metadata, version, node, values));
  213. }
  214. break;
  215. }
  216. default: {
  217. throw new cntk.Error(`Unsupported graph version '${version}'.`);
  218. }
  219. }
  220. }
  221. };
  222. cntk.Argument = class {
  223. constructor(name, value) {
  224. this.name = name;
  225. this.value = value;
  226. }
  227. };
  228. cntk.Value = class {
  229. constructor(version, obj) {
  230. switch (version) {
  231. case 1:
  232. switch (obj.__type__) {
  233. case 'InputValue':
  234. this.name = obj.name;
  235. this.type = new cntk.TensorType(version, obj.precision, obj.sampleLayout);
  236. this.initializer = null;
  237. break;
  238. case 'LearnableParameter':
  239. this.name = obj.name;
  240. this.initializer = new cntk.Tensor(version, obj);
  241. this.type = this.initializer.type;
  242. break;
  243. default:
  244. this.name = obj.name;
  245. this.type = null;
  246. this.initializer = null;
  247. break;
  248. }
  249. break;
  250. case 2:
  251. if (obj.value) {
  252. this.name = obj.name || obj.uid;
  253. this.type = null;
  254. this.initializer = new cntk.Tensor(version, obj);
  255. } else {
  256. this.name = obj.uid;
  257. if (obj.data_type && obj.shape) {
  258. this.type = new cntk.TensorType(version, obj.data_type, obj.shape);
  259. }
  260. this.initializer = null;
  261. }
  262. break;
  263. default:
  264. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  265. }
  266. }
  267. };
  268. cntk.Node = class {
  269. constructor(metadata, version, obj, values) {
  270. this.attributes = [];
  271. this.inputs = [];
  272. this.outputs = [];
  273. let inputs = [];
  274. let outputs = [];
  275. switch (version) {
  276. case 1: {
  277. const type = obj.__type__;
  278. this.type = metadata.type(type) || { name: type };
  279. this.name = obj.name;
  280. for (const [name, value] of Object.entries(obj)) {
  281. if (name != '__type__' && name != 'name' && name != 'inputs' && name != 'precision') {
  282. const attribute = new cntk.Attribute(metadata.attribute(type, name), name, value);
  283. this.attributes.push(attribute);
  284. }
  285. }
  286. inputs = obj.inputs.map((input) => values.map(input, version));
  287. outputs = [ values.map(this.name, version) ];
  288. break;
  289. }
  290. case 2: {
  291. this.name = obj.name || obj.uid || null;
  292. const output = obj.uid;
  293. if (obj.op == 57) {
  294. this.type = metadata.type(obj.uid) || { name: obj.uid };
  295. } else if (Object.prototype.hasOwnProperty.call(obj, 'op')) {
  296. // cntk/Source/CNTKv2LibraryDll/API/Internals/PrimitiveOpType.h
  297. this.type = metadata.type(obj.op.toNumber());
  298. } else {
  299. const type = obj.type;
  300. this.type = metadata.type(type) || { name: type };
  301. if (obj.user_defined_state) {
  302. for (const [name, value] of Object.entries(obj.user_defined_state)) {
  303. const attribute = new cntk.Attribute(metadata.attribute(type, name), name, value);
  304. this.attributes.push(attribute);
  305. }
  306. }
  307. }
  308. if (obj.attributes) {
  309. for (const [name, value] of Object.entries(obj.attributes)) {
  310. const attribute = new cntk.Attribute(metadata.attribute(this.type, name), name, value);
  311. this.attributes.push(attribute);
  312. }
  313. }
  314. inputs = obj.inputs.map((input) => values.map(input, version));
  315. outputs.push(values.map(`${output}_Output_0`, version));
  316. break;
  317. }
  318. default: {
  319. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  320. }
  321. }
  322. let inputIndex = 0;
  323. if (this.type && this.type.inputs) {
  324. for (const inputSchema of this.type.inputs) {
  325. if (inputIndex < inputs.length || inputSchema.option != 'optional') {
  326. const inputCount = inputSchema.type === 'Tensor[]' ? (inputs.length - inputIndex) : 1;
  327. const inputArguments = [];
  328. for (const inputArgument of inputs.slice(inputIndex, inputIndex + inputCount)) {
  329. if (inputArgument.name != '' || inputSchema.option != 'optional') {
  330. inputArguments.push(inputArgument);
  331. }
  332. }
  333. this.inputs.push(new cntk.Argument(inputSchema.name, inputArguments));
  334. inputIndex += inputCount;
  335. }
  336. }
  337. }
  338. this.inputs.push(...inputs.slice(inputIndex).map((argument, index) => {
  339. return new cntk.Argument((inputIndex + index).toString(), [ argument ]);
  340. }));
  341. let outputIndex = 0;
  342. if (this.type && this.type.outputs) {
  343. for (const outputSchema of this.type.outputs) {
  344. if (outputIndex < outputs.length || !outputSchema.optional) {
  345. const outputCount = outputSchema.type === 'Tensor[]' ? (outputs.length - outputIndex) : 1;
  346. this.outputs.push(new cntk.Argument(outputSchema.name, outputs.slice(outputIndex, outputIndex + outputCount)));
  347. outputIndex += outputCount;
  348. }
  349. }
  350. }
  351. this.outputs.push(...outputs.slice(outputIndex).map((argument) => {
  352. return new cntk.Argument(outputIndex.toString(), [ argument ]);
  353. }));
  354. }
  355. };
  356. cntk.Attribute = class {
  357. constructor(metadata, name, value) {
  358. this.name = name;
  359. this.value = value;
  360. this.type = null;
  361. if (this.value && this.value.__type__ === 'shape') {
  362. this.value = new cntk.TensorShape(1, value);
  363. this.type = 'shape';
  364. }
  365. if (cntk.proto && this.value instanceof cntk.proto.NDShape) {
  366. this.value = new cntk.TensorShape(2, value);
  367. this.type = 'shape';
  368. }
  369. if (cntk.proto && this.value instanceof cntk.proto.Axis) {
  370. const axis = { __type__: 'Axis' };
  371. for (const key of Object.keys(value).filter((key) => key !== 'name')) {
  372. axis[key] = value[key];
  373. }
  374. this.value = axis;
  375. }
  376. if (metadata) {
  377. if (metadata.type) {
  378. this.type = metadata.type;
  379. const type = cntk[this.type] || cntk.proto[this.type];
  380. if (type && type[this.value]) {
  381. this.value = type[this.value];
  382. }
  383. }
  384. if (metadata.visible === false) {
  385. this.visible = false;
  386. } else if (Object.prototype.hasOwnProperty.call(metadata, 'default')) {
  387. let defaultValue = metadata.default;
  388. value = this.value;
  389. if (typeof value == 'function') {
  390. value = value();
  391. }
  392. if (this.type == 'shape') {
  393. value = value.dimensions;
  394. }
  395. if (value == defaultValue) {
  396. this.visible = false;
  397. } else if (Array.isArray(value) && Array.isArray(defaultValue)) {
  398. defaultValue = defaultValue.slice(0, defaultValue.length);
  399. if (defaultValue.length > 1 && defaultValue[defaultValue.length - 1] == null) {
  400. defaultValue.pop();
  401. while (defaultValue.length < value.length) {
  402. defaultValue.push(defaultValue[defaultValue.length - 1]);
  403. }
  404. }
  405. if (value.every((item, index) => item == defaultValue[index])) {
  406. this.visible = false;
  407. }
  408. }
  409. }
  410. }
  411. }
  412. };
  413. cntk.Tensor = class {
  414. constructor(version, tensor) {
  415. this.encoding = '|';
  416. this.values = null;
  417. switch (version) {
  418. case 1: {
  419. if (tensor.__type__ == 'LearnableParameter') {
  420. this.name = tensor.name || null;
  421. this.type = new cntk.TensorType(version, tensor.precision, tensor.sampleLayout);
  422. }
  423. break;
  424. }
  425. case 2: {
  426. this.name = tensor.name || tensor.uid || null;
  427. this.type = new cntk.TensorType(version, tensor.data_type, tensor.shape);
  428. const value = tensor.value;
  429. if (this.type.dataType === 'float32' && value && value.float_values && value.float_values.value && value.float_values.value.length > 0) {
  430. this.values = value.float_values.value;
  431. }
  432. break;
  433. }
  434. default:
  435. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  436. }
  437. }
  438. };
  439. cntk.TensorType = class {
  440. constructor(version, dataType, shape) {
  441. this.dataType = '?';
  442. switch (version) {
  443. case 1:
  444. switch (dataType) {
  445. case 'float': this.dataType = 'float32'; break;
  446. case 'double': this.dataType = 'float64'; break;
  447. case 'half': this.dataType = 'float16'; break;
  448. case '': this.dataType = 'float32'; break;
  449. default: throw new cntk.Error(`Unsupported tensor data type '${dataType}'.`);
  450. }
  451. this.shape = new cntk.TensorShape(version, shape);
  452. break;
  453. case 2:
  454. dataType = dataType.toNumber();
  455. switch (dataType) {
  456. case 1: this.dataType = 'float32'; break;
  457. default: throw new cntk.Error(`Unsupported tensor data type '${dataType}'.`);
  458. }
  459. this.shape = new cntk.TensorShape(version, shape);
  460. break;
  461. default:
  462. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  463. }
  464. }
  465. toString() {
  466. return this.dataType + this.shape.toString();
  467. }
  468. };
  469. cntk.TensorShape = class {
  470. constructor(version, shape) {
  471. switch (version) {
  472. case 1:
  473. this.dimensions = shape.dims;
  474. break;
  475. case 2:
  476. this.dimensions = shape.shape_dim.map((dimension) => dimension.toNumber());
  477. break;
  478. default:
  479. throw new cntk.Error(`Unsupported CNTK version '${version}'.`);
  480. }
  481. }
  482. toString() {
  483. return (this.dimensions && this.dimensions.length) ? (`[${this.dimensions.join(',')}]`) : '';
  484. }
  485. };
  486. cntk.Function = class {
  487. constructor(name, nodes, inputs, outputs) {
  488. this.type = 'function';
  489. this.name = name;
  490. this.inputs = inputs;
  491. this.outputs = outputs;
  492. this.nodes = nodes;
  493. switch (this.name) {
  494. case 'PReLU':
  495. case 'Softmax':
  496. this.category = 'Activation';
  497. break;
  498. case 'Dropout':
  499. this.category = 'Dropout';
  500. break;
  501. case 'Convolution':
  502. case 'ConvolutionTranspose':
  503. case 'Dense':
  504. case 'linear':
  505. case 'LSTM':
  506. this.category = 'Layer';
  507. break;
  508. case 'BatchNormalization':
  509. case 'lrn':
  510. this.category = 'Normalization';
  511. break;
  512. case 'AveragePooling':
  513. case 'MaxPooling':
  514. this.category = 'Pool';
  515. break;
  516. default:
  517. this.category = null;
  518. break;
  519. }
  520. }
  521. };
  522. cntk.GraphMetadata = class {
  523. constructor(metadata) {
  524. this._metadata = metadata;
  525. this._functions = new Map();
  526. this._attributes = new Map();
  527. }
  528. add(name, func) {
  529. if (this._functions.has(name)) {
  530. throw new cntk.Error(`Duplicate function identifier '${func.name}'.`);
  531. }
  532. this._functions.set(name, func);
  533. }
  534. name(code) {
  535. // cntk/Source/CNTKv2LibraryDll/API/Internals/PrimitiveOpType.h
  536. return this._metadata.name(code);
  537. }
  538. type(name) {
  539. if (this._functions.has(name)) {
  540. return this._functions.get(name);
  541. }
  542. return this._metadata.type(name);
  543. }
  544. attribute(type, name) {
  545. const key = `${type}:${name}`;
  546. if (!this._attributes.has(key)) {
  547. const metadata = this.type(type);
  548. if (metadata && metadata.attributes && metadata.attributes.length > 0) {
  549. for (const attribute of metadata.attributes) {
  550. this._attributes.set(`${type}:${attribute.name}`, attribute);
  551. }
  552. }
  553. if (!this._attributes.has(key)) {
  554. this._attributes.set(key, null);
  555. }
  556. }
  557. return this._attributes.get(key);
  558. }
  559. };
  560. cntk.ComputationNetwork = class {
  561. constructor(buffer) {
  562. const reader = new base.BinaryReader(buffer);
  563. reader.match = function(text) {
  564. const position = this.position;
  565. for (let i = 0; i < text.length; i++) {
  566. if (this.uint16() != text.charCodeAt(i)) {
  567. this.seek(position);
  568. return false;
  569. }
  570. }
  571. if (this.uint16() != 0) {
  572. this.seek(position);
  573. return false;
  574. }
  575. return true;
  576. };
  577. reader.assert = function(text) {
  578. if (!this.match(text)) {
  579. throw new cntk.Error(`Invalid '${text}' signature.`);
  580. }
  581. };
  582. reader.string = function() {
  583. const content = [];
  584. let c = this.uint16();
  585. while (c != 0) {
  586. content.push(String.fromCharCode(c));
  587. c = this.uint16();
  588. }
  589. return content.join('');
  590. };
  591. reader.strings = function() {
  592. const count = this.uint64();
  593. const array = new Array(count);
  594. for (let i = 0; i < count; i++) {
  595. array[i] = this.string();
  596. }
  597. return array;
  598. };
  599. reader.booleans = function() {
  600. const count = this.uint64();
  601. const array = new Array(count);
  602. for (let i = 0; i < count; i++) {
  603. array[i] = this.boolean();
  604. }
  605. return array;
  606. };
  607. reader.matrix = function () {
  608. const type = this.byte();
  609. switch (type) {
  610. case 100: {
  611. // dense
  612. this.assert('BMAT');
  613. const elsize = this.uint64();
  614. const value = {};
  615. value.name = this.string();
  616. value.format = this.uint32();
  617. value.rows = this.uint64();
  618. value.columns = this.uint64();
  619. this.read(elsize * value.rows * value.columns);
  620. this.assert('EMAT');
  621. return value;
  622. }
  623. case 115: // sparse
  624. throw new cntk.Error('Matrix sparse type not implemented.');
  625. default:
  626. throw new cntk.Error(`Matrix type '${type}' not implemented.`);
  627. }
  628. };
  629. reader.shape = function(acceptLegacyFormat) {
  630. const dims = [];
  631. const rank = this.uint32();
  632. let dim0 = 0;
  633. if (rank > 0) {
  634. dim0 = this.uint32();
  635. }
  636. if (!acceptLegacyFormat || dim0 !== 0) {
  637. if (rank > 0) {
  638. dims.push(dim0);
  639. }
  640. for (let i = 1; i < rank; i++) {
  641. dims.push(this.uint32());
  642. }
  643. } else {
  644. const dim = this.uint32();
  645. dims.push(this.uint32());
  646. dims.push(rank);
  647. dims.push(dim);
  648. }
  649. return { __type__: 'shape', dims: dims };
  650. };
  651. const shape = (dims) => {
  652. return { __type__: 'shape', dims: dims };
  653. };
  654. reader.assert('BCN');
  655. reader.assert('BVersion');
  656. this.version = reader.uint64();
  657. reader.assert('EVersion');
  658. const numNodes = reader.uint64();
  659. reader.assert('BNodeList');
  660. const op = {};
  661. op.Minus = function() {};
  662. op.Plus = function() {};
  663. op.GreaterEqual = function() {};
  664. op.Equal = function() {};
  665. op.NotEqual = function() {};
  666. op.GreaterEqual = function() {};
  667. op.Exp = function() {};
  668. op.Log = function() {};
  669. op.Reciprocal = function() {};
  670. op.ElementTimes = function() {};
  671. op.ClassificationError = function() {};
  672. op.RectifiedLinear = function() {};
  673. op.InputValue = function(reader, version) {
  674. this.rows = reader.uint64();
  675. this.cols = reader.uint64();
  676. this.sampleLayout = reader.shape(true);
  677. this.dynamicAxisNodeName = '';
  678. if (version >= 8) {
  679. const nrAxes = reader.uint32();
  680. if (nrAxes == 1) {
  681. this.dynamicAxisNodeName = reader.string();
  682. }
  683. }
  684. this.learningRateMultiplier = 0;
  685. if (version >= 10) {
  686. this.learningRateMultiplier = reader.float32();
  687. }
  688. };
  689. op.LearnableParameter = function(reader, version) {
  690. if (version >= 3) {
  691. this.learningRateMultiplier = reader.float32();
  692. this.sampleLayout = reader.shape(false);
  693. } else {
  694. throw new cntk.Error('LeanableParameter reader implemented.');
  695. }
  696. this.value = reader.matrix();
  697. };
  698. op.CrossEntropyWithSoftmax = function(reader) {
  699. this.evalMode = reader.uint32();
  700. if (this.evalMode > 2) {
  701. this.evalMode = 0;
  702. reader.skip(-4);
  703. }
  704. };
  705. op.Times = function(reader, version) {
  706. this.outputRank = (version >= 3) ? reader.uint64() : 1;
  707. this.inferInputRankToMap = (version >= 12) ? reader.int32() : -1;
  708. };
  709. op.Dropout = function(reader, version) {
  710. if (version >= 16) {
  711. this.rngSeed = (version == 16) ? reader.uint32() : reader.uint64();
  712. this.rngOffset = reader.uint64();
  713. }
  714. };
  715. op.ConvolutionBase = function(reader, version) {
  716. if (version >= 5) {
  717. this.kernelShape = reader.shape(false);
  718. this.mapCount = reader.shape(false);
  719. this.strides = reader.shape(false);
  720. this.sharing = reader.booleans();
  721. this.autoPadding = reader.booleans();
  722. this.lowerPad = reader.shape(false);
  723. this.upperPad = reader.shape(false);
  724. this.poolKind = reader.int32();
  725. this.imageLayoutKind = reader.int32();
  726. this.maxTempMemSizeInSamples = reader.uint64();
  727. }
  728. if (version >= 9) {
  729. this.transpose = reader.boolean();
  730. }
  731. if (version >= 20) {
  732. this.outputShape = reader.shape(false);
  733. }
  734. if (version >= 21) {
  735. this.ceilOutDim = reader.boolean();
  736. }
  737. if (version >= 23) {
  738. this.includePad = reader.boolean();
  739. }
  740. };
  741. op.Convolution = function(reader, version) {
  742. op.ConvolutionBase.apply(this, [ reader, version ]);
  743. if (version < 5) {
  744. this.kernelShape = shape([ reader.uint64(), reader.uint64(), 1 ]);
  745. this.strides = shape([ reader.uint64(), reader.uint64(), 1 ]);
  746. this.mapCount = shape([ reader.uint32() ]);
  747. this.imageLayoutKind = reader.int32();
  748. this.autoPadding = [ reader.boolean() ];
  749. this.maxTempMemSizeInSamples = reader.uint64();
  750. this.poolKind = 'None';
  751. this.convolution2D = true;
  752. this.sharing = [ true ];
  753. this.lowerPad = shape([ 0 ]);
  754. this.upperPad = shape([ 0 ]);
  755. } else {
  756. this.convolution2D = reader.boolean();
  757. if (version >= 18) {
  758. this.dilation = reader.shape();
  759. } else {
  760. this.dilation = shape([ 1 ]);
  761. }
  762. }
  763. };
  764. op.Pooling = function(reader, version) {
  765. op.ConvolutionBase.apply(this, [ reader, version ]);
  766. };
  767. op.PoolingBase = function(reader) {
  768. this.imageLayoutKind = reader.int32();
  769. this.windowWidth = reader.uint32();
  770. this.windowHeight = reader.uint64();
  771. this.horizontalSubsample = reader.uint64();
  772. this.verticalSubsample = reader.uint64();
  773. };
  774. op.MaxPooling = function(reader, version) {
  775. op.PoolingBase.apply(this, [ reader, version ]);
  776. };
  777. op.ROIPooling = function(reader, version) {
  778. this.roiOutputShape = reader.shape(false);
  779. this.poolKind = (version < 26) ? 'Max' : reader.int32();
  780. this.spatialScale = (version < 26) ? 0.0625 : reader.float64();
  781. };
  782. op.Reshape = function(reader) {
  783. this.beginDimParameter = reader.uint32();
  784. this.endDimParameter = reader.uint32();
  785. this.replacementSampleLayout = reader.shape(false);
  786. };
  787. op.ReduceElements = function(reader, version) {
  788. let num_axes = 1;
  789. if (version >= 27) {
  790. num_axes = reader.uint32();
  791. }
  792. this.axes = [];
  793. for (let i = 0; i < num_axes; i++) {
  794. this.axes.push(reader.uint32());
  795. }
  796. this.operation = reader.string();
  797. if (version >= 24) {
  798. this.keepDimensions = reader.boolean();
  799. }
  800. };
  // BatchNormalization node. Models from v6 on use a fixed field layout; older models embed
  // their own (written, readable) format-version pair that gates which fields are present.
  op.BatchNormalization = function(reader, version) {
      // uint64 count read only from pre-v13 streams; converted into runCountUntied at the end.
      let legacyCount = 0;
      if (version >= 6) {
          this.spatial = reader.boolean();
          this.normalizationTimeConstant = reader.float64();
          this.blendTimeConstant = reader.float64();
          this.imageLayoutKind = reader.int32();
          if (version >= 13) {
              // v19 stores a boolean in this slot instead of a uint64.
              this.runCountUntied = (version == 19) ? (reader.boolean() ? 0 : 'SIZE_MAX') : reader.uint64(); // TODO
          } else {
              legacyCount = reader.uint64();
          }
          this.epsilon = reader.float64();
          this.useCntkEngine = reader.boolean();
      } else {
          const written = reader.int32();
          const readable = reader.int32();
          if (readable > written || written < 0x00010001 || readable > 0x00010004) {
              throw new cntk.Error('BatchNormalization version not supported.');
          }
          this.eval = reader.boolean();
          this.spatial = reader.boolean();
          // This float64 is always present; before format 0x00010004 it held expAvgFactor, which is discarded.
          const timeConstantSlot = reader.float64();
          if (written >= 0x00010004) {
              this.normalizationTimeConstant = timeConstantSlot;
          }
          if (written >= 0x00010002) {
              this.imageLayoutKind = reader.int32();
              legacyCount = reader.uint64();
          }
          if (written >= 0x00010003) {
              this.epsilon = reader.float64();
              this.useCntkEngine = reader.boolean();
          }
      }
      if (version < 13) {
          // NOTE(review): the factor 16 presumably converts a minibatch count into a sample count — confirm.
          this.runCountUntied = 16 * legacyCount;
          this.convertRunningVariancePending = true;
      }
  };
  // Node types with no type-specific fields: the node loop has already consumed the
  // precision, type name and node name, so these parsers read nothing further.
  for (const noFieldType of [ 'Tanh', 'Sigmoid', 'Logistic', 'SquareError', 'ErrorPrediction' ]) {
      op[noFieldType] = function() {};
  }
  // RowStack node: the splice dimension is serialized only from v3 on; older models default to 1.
  op.RowStack = function(reader, version) {
      let spliceDim = 1;
      if (version >= 3) {
          spliceDim = reader.int32();
      }
      this.spliceDim = spliceDim;
  };
  // Slice node: one or more slice entries, each a pair of uint64 indices plus,
  // depending on the model version, an int32 axis (v3+) and an int32 stride multiplier (v27+).
  op.Slice = function(reader, version) {
      // Models before v22 store a single entry without an explicit count.
      const entryCount = (version >= 22) ? reader.int32() : 1;
      const hasAxis = version >= 3;
      const hasStride = version >= 27;
      this.index = [];
      this.axis = [];
      this.strideMultiplier = [];
      for (let entry = 0; entry < entryCount; entry++) {
          const first = reader.uint64();
          const second = reader.uint64();
          this.index.push([ [ first, second ] ]);
          if (hasAxis) {
              this.axis.push(reader.int32());
          }
          if (hasStride) {
              this.strideMultiplier.push(reader.int32());
          }
      }
  };
  // PastValue node: time step, sample layout, and (v2+) an int32 initial state value.
  op.PastValue = function(reader, version) {
      this.timeStep = reader.int32();
      if (version > 3) {
          this.sampleLayout = reader.shape(false);
      } else {
          // Older models store two uint64 values; both are consumed, only the first (rows) is used.
          const [ rows ] = [ reader.uint64(), reader.uint64() ];
          this.sampleLayout = shape([ rows ], true);
      }
      if (version >= 2) {
          this.initialStateValue = reader.int32();
      }
  };
  // FutureValue node. Its serialized fields (timeStep, sampleLayout, and from v2 an
  // initialStateValue) are laid out exactly like PastValue's, so this delegates to
  // that parser instead of keeping a second copy of the same reads, in the same way
  // AveragePooling delegates to PoolingBase.
  op.FutureValue = function(reader, version) {
      op.PastValue.call(this, reader, version);
  };
  // TransposeDimensions node: the two swapped axes (v3+), optionally followed (v25+) by an explicit permutation.
  op.TransposeDimensions = function(reader, version) {
      if (version >= 3) {
          this.axis1 = reader.int32();
          this.axis2 = reader.int32();
          // From v25, axis1 == 0 && axis2 == 0 marks that a uint64-length permutation follows.
          const hasPermutation = version >= 25 && this.axis1 == 0 && this.axis2 == 0;
          if (hasPermutation) {
              const length = reader.uint64();
              const permutation = [];
              for (let p = 0; p < length; p++) {
                  permutation.push(reader.uint64());
              }
              this.perm = permutation;
          }
          return;
      }
      // Pre-v3 models do not serialize the axes; fall back to 1 and 2.
      this.axis1 = 1;
      this.axis2 = 2;
  };
  // AveragePooling node: no fields of its own; the shared PoolingBase parser reads everything.
  op.AveragePooling = function(reader, version) {
      op.PoolingBase.call(this, reader, version);
  };
  // InvStdDev node: a boolean hasComputed flag followed by a stored matrix.
  op.InvStdDev = function(reader) {
      const hasComputed = reader.boolean();
      const value = reader.matrix();
      Object.assign(this, { hasComputed, value });
  };
  // Mean node: a boolean hasComputed flag followed by a stored matrix.
  // Object literal members are evaluated left to right, so the reads happen in stream order.
  op.Mean = function(reader) {
      Object.assign(this, {
          hasComputed: reader.boolean(),
          value: reader.matrix()
      });
  };
  // More node types without type-specific fields; their parsers read nothing.
  for (const noFieldType of [ 'PerDimMeanVarNormalization', 'Softmax', 'DynamicAxis' ]) {
      op[noFieldType] = function() {};
  }
  928. const nodes = [];
  929. this.nodes = {};
  930. for (let i = 0; i < numNodes; i++) {
  931. const precision = this.version >= 7 ? reader.string() : '';
  932. if (precision != 'float' && precision != 'double' && precision != 'half' && precision != '') {
  933. throw new cntk.Error(`Invalid precision format '${precision}'.`);
  934. }
  935. const obj = { __type__: reader.string() };
  936. obj.name = reader.string();
  937. obj.precision = precision;
  938. const constructor = op[obj.__type__];
  939. if (!constructor) {
  940. throw new cntk.Error(`Unsupported node type '${obj.__type__}'.`);
  941. }
  942. constructor.apply(obj, [ reader, this.version ]);
  943. nodes.push(obj);
  944. this.nodes[obj.name] = obj;
  945. }
  946. reader.assert('ENodeList');
  947. reader.assert('BRelation');
  948. for (let j = 0; j < numNodes; j++) {
  949. const nodeName = reader.string();
  950. const node = this.nodes[nodeName];
  951. const numChildren = reader.uint64();
  952. const children = [];
  953. for (let k = 0; k < numChildren; k++) {
  954. children.push(reader.string());
  955. }
  956. if (this.version < 19 && node.__type__ == 'BatchNormalization') {
  957. const runSampleCount = {
  958. __type__: 'LearnableParameter',
  959. name: `${nodeName}.run_sample_count`,
  960. precision: node.precision,
  961. sampleLayout: shape([ 1 ]), // TODO set value = 0
  962. learningRateMultiplier: 0
  963. };
  964. nodes.push(runSampleCount);
  965. this.nodes[runSampleCount.name] = runSampleCount;
  966. children.push(runSampleCount.name);
  967. }
  968. if (node.__type__ == 'Convolution' && children.length > 1) {
  969. children.splice(0, 0, children.pop());
  970. }
  971. node.inputs = children;
  972. }
  973. reader.assert('ERelation');
  974. reader.assert('BRootNodes');
  975. if (reader.match('BFeatureNodes')) {
  976. this.feature = reader.strings();
  977. reader.assert('EFeatureNodes');
  978. }
  979. if (reader.match('BLabelNodes')) {
  980. this.label = reader.strings();
  981. reader.assert('ELabelNodes');
  982. }
  983. if (reader.match('BCriterionNodes')) {
  984. this.criterion = reader.strings();
  985. reader.assert('ECriterionNodes');
  986. }
  987. if (this.criterion.length == 0) {
  988. if (reader.match('BCriteriaNodes')) {
  989. this.criterion = reader.strings();
  990. reader.assert('ECriteriaNodes');
  991. }
  992. }
  993. if (reader.match('BNodesReqMultiSeqHandling')) {
  994. reader.strings();
  995. reader.assert('ENodesReqMultiSeqHandling');
  996. }
  997. if (reader.match('BEvalNodes')) {
  998. this.eval = reader.strings();
  999. reader.assert('EEvalNodes');
  1000. }
  1001. if (reader.match('BOutputNodes')) {
  1002. this.output = reader.strings();
  1003. reader.assert('EOutputNodes');
  1004. }
  1005. if (reader.match('BPairNodes')) {
  1006. this.pair = reader.strings();
  1007. reader.assert('EPairNodes');
  1008. }
  1009. reader.assert('ERootNodes');
  1010. reader.assert('ECN');
  1011. }
  1012. };
  // Lookup from a serialized image-layout enum value to its layout name.
  // NOTE(review): presumably used to decode BatchNormalization's imageLayoutKind — confirm.
  cntk.ImageLayoutKind = {
      '0': 'CHW',
      '1': 'HWC'
  };
  // Lookup from a serialized pooling-kind enum value to its name.
  cntk.PoolKind = {
      '0': 'None',
      '1': 'Max',
      '2': 'Average'
  };
  // Error raised by this parser for invalid or unsupported model content
  // (e.g. unknown node types, bad precision strings, unsupported BatchNormalization versions).
  // NOTE: `name` holds a human-readable title rather than a class name. It is kept as-is
  // because code outside this view may display it.
  cntk.Error = class extends Error {
      /**
       * @param {string} message - description of the failure.
       * @param {{ cause?: unknown }} [options] - optional; pass `{ cause }` when wrapping
       *     an underlying error so its stack is preserved. Omitting it behaves as before.
       */
      constructor(message, options) {
          super(message, options);
          this.name = 'Error loading CNTK model.';
      }
  };
  // Public entry point of this module.
  export const { ModelFactory } = cntk;