// hickle.js (5.0 KB)

  // Namespace object that collects the Hickle reader classes assigned below.
  const hickle = {};

  /**
   * Detects and opens Hickle files: HDF5 containers whose root group has a
   * `CLASS` attribute equal to 'hickle'.
   */
  hickle.ModelFactory = class {

    /**
     * Checks whether the context holds a Hickle-flavoured HDF5 group.
     * @param {object} context - host context exposing `peek(kind)` and `set(kind, value)`.
     * @returns {Promise<*>} whatever `context.set('hickle', group)` returns on a match, otherwise null.
     */
    async match(context) {
      const root = await context.peek('hdf5');
      if (!root || !root.attributes) {
        return null;
      }
      if (root.attributes.get('CLASS') !== 'hickle') {
        return null;
      }
      return context.set('hickle', root);
    }

    /**
     * Builds the model from `context.value`.
     * NOTE(review): presumably `context.value` is the group registered by match() — confirm with the host.
     * @param {object} context - host context exposing `value`.
     * @returns {Promise<hickle.Model>}
     */
    async open(context) {
      const root = context.value;
      return new hickle.Model(root);
    }
  };
  /**
   * Top-level model: the 'Hickle Weights' format with a single graph module.
   */
  hickle.Model = class {

    /**
     * @param {object} group - HDF5 group handed to hickle.Graph.
     */
    constructor(group) {
      const graph = new hickle.Graph(group);
      this.format = 'Hickle Weights';
      this.modules = [graph];
    }
  };
  /**
   * Graph view of a Hickle weights file.
   *
   * The root group is deserialized into plain JavaScript values. When the
   * result is a Map whose entries all look like tensors (truthy `type` and
   * `shape`), each key is split at its last '.' into a layer name and a
   * parameter name (e.g. `conv1.weight` -> layer `conv1`, parameter `weight`),
   * and one hickle.Node is created per layer. Any other result yields no nodes.
   */
  hickle.Graph = class {

    /**
     * @param {object} group - root HDF5 group of the Hickle file.
     * @throws {hickle.Error} when a group lacks a usable `type` attribute, a
     *   'hickle'/'dict_item' wrapper does not have exactly one child, or the
     *   type is not one of 'hickle', 'dict_item', 'dict', 'ndarray'.
     */
    constructor(group) {
      this.inputs = [];
      this.outputs = [];
      const deserialize = (node) => {
        // Guard `attributes` as well so a malformed group raises hickle.Error, not a TypeError.
        if (node && node.attributes && node.attributes.has('type')) {
          const type = node.attributes.get('type');
          if (Array.isArray(type) && type.length > 0 && typeof type[0] === 'string') {
            switch (type[0]) {
              case 'hickle':
              case 'dict_item': {
                // Wrapper groups must hold exactly one child, which is the wrapped value.
                if (node.groups.size === 1) {
                  return deserialize(node.groups.values().next().value);
                }
                throw new hickle.Error(`Invalid Hickle type value '${type[0]}'.`);
              }
              case 'dict': {
                const dict = new Map();
                for (const [name, child] of node.groups) {
                  dict.set(name, deserialize(child));
                }
                return dict;
              }
              case 'ndarray': {
                return node.value;
              }
              default: {
                throw new hickle.Error(`Unsupported Hickle type '${type[0]}'`);
              }
            }
          }
          throw new hickle.Error(`Unsupported Hickle type '${JSON.stringify(type)}'`);
        }
        throw new hickle.Error('Unsupported Hickle group.');
      };
      // A null/undefined array value (e.g. an 'ndarray' group without data) is
      // simply "not a tensor" rather than a TypeError. An empty shape array is still truthy.
      const isTensorLike = (value) => Boolean(value && value.type && value.shape);
      const obj = deserialize(group);
      const layers = new Map();
      if (obj instanceof Map && Array.from(obj.values()).every(isTensorLike)) {
        for (const [key, value] of obj) {
          // String tensors carry their decoded payload in `value`; others carry raw `data`.
          const data = value.type === 'string' ? value.value : value.data;
          const tensor = new hickle.Tensor(key, value.shape, value.type, value.littleEndian, data);
          const bits = key.split('.');
          const parameter = bits.pop();
          const layer = bits.join('.');
          if (!layers.has(layer)) {
            layers.set(layer, []);
          }
          layers.get(layer).push({ name: parameter, value: tensor });
        }
      }
      this.nodes = Array.from(layers).map(([name, value]) => new hickle.Node(name, value));
    }
  };
  /**
   * Named argument slot of a node, pairing a name with its bound value(s).
   */
  hickle.Argument = class {

    /**
     * @param {string} name - argument name.
     * @param {*} value - bound value; hickle.Node passes an array of hickle.Value.
     */
    constructor(name, value) {
      Object.assign(this, { name, value });
    }
  };
  /**
   * A named value, optionally backed by an initializer tensor.
   */
  hickle.Value = class {

    /**
     * @param {string} name - value identifier; must be a string.
     * @param {*} type - explicit type; when falsy and an initializer is given,
     *   the initializer's `type` is used instead.
     * @param {*} [initializer=null] - backing tensor, if any.
     * @throws {hickle.Error} when `name` is not a string.
     */
    constructor(name, type, initializer = null) {
      const hasValidName = typeof name === 'string';
      if (!hasValidName) {
        throw new hickle.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
      }
      const inheritsType = !type && Boolean(initializer);
      this.name = name;
      this.type = inheritsType ? initializer.type : type;
      this.initializer = initializer;
    }
  };
  /**
   * A 'Weights' node whose inputs are the parameter tensors of one layer.
   */
  hickle.Node = class {

    /**
     * @param {string} name - node name.
     * @param {Array<{name: string, value: object}>} parameters - parameter name plus its tensor
     *   (the tensor's own `name` becomes the value identifier).
     */
    constructor(name, parameters) {
      this.type = { name: 'Weights' };
      this.name = name;
      const toArgument = ({ name: parameterName, value: tensor }) => {
        const bound = new hickle.Value(tensor.name, null, tensor);
        return new hickle.Argument(parameterName, [bound]);
      };
      this.inputs = parameters.map(toArgument);
      this.outputs = [];
      this.attributes = [];
    }
  };
  /**
   * Weight tensor: name, type (data type + shape), byte order and raw payload.
   */
  hickle.Tensor = class {

    /**
     * @param {string} name - tensor name.
     * @param {*} shape - dimensions passed to hickle.TensorShape.
     * @param {*} type - data type passed to hickle.TensorType.
     * @param {boolean} littleEndian - byte order of the payload.
     * @param {*} data - a Uint8Array, an object exposing `peek()`, an array, or null/undefined.
     */
    constructor(name, shape, type, littleEndian, data) {
      this.name = name;
      this.type = new hickle.TensorType(type, new hickle.TensorShape(shape));
      // '<' marks little-endian, '>' big-endian.
      this.encoding = littleEndian ? '<' : '>';
      this._data = data;
    }

    /**
     * Raw tensor bytes, or null when none are available.
     * @returns {Uint8Array|*|null}
     */
    get values() {
      // Missing payloads (null or undefined) and array payloads (presumably decoded
      // string values) expose no raw bytes; previously `undefined` crashed on `.peek()`.
      if (this._data === null || this._data === undefined || Array.isArray(this._data)) {
        return null;
      }
      if (this._data instanceof Uint8Array) {
        return this._data;
      }
      // NOTE(review): assumes any other payload is a lazily-read buffer exposing peek().
      return this._data.peek();
    }
  };
  /**
   * Tensor type: a data type paired with a shape object.
   */
  hickle.TensorType = class {

    /**
     * @param {*} dataType - data type name.
     * @param {object} shape - shape object with its own toString().
     */
    constructor(dataType, shape) {
      Object.assign(this, { dataType, shape });
    }

    /**
     * Renders the data type immediately followed by the shape's text,
     * e.g. 'float32' with a shape rendering '[3,3]' gives 'float32[3,3]'.
     * @returns {string}
     */
    toString() {
      const suffix = this.shape.toString();
      return this.dataType + suffix;
    }
  };
  /**
   * Tensor shape wrapper around a list of dimensions.
   */
  hickle.TensorShape = class {

    /**
     * @param {Array<*>|null|undefined} dimensions - dimension list; falsy means unknown.
     */
    constructor(dimensions) {
      this.dimensions = dimensions;
    }

    /**
     * Renders dimensions as '[d0,d1,...]', or '' when no dimension list is present.
     * @returns {string}
     */
    toString() {
      if (!this.dimensions) {
        return '';
      }
      const parts = this.dimensions.map((size) => size.toString());
      return `[${parts.join(',')}]`;
    }
  };
  /**
   * Error raised for malformed or unsupported Hickle content.
   */
  hickle.Error = class extends Error {

    /**
     * @param {string} text - human-readable description of the problem.
     */
    constructor(text) {
      super(text);
      // Fixed, descriptive name shown instead of the generic 'Error'.
      this.name = 'Error loading Hickle model.';
    }
  };
  // Public entry point: the Hickle model factory.
  export const { ModelFactory } = hickle;