// transformers.js
  1. // import * as python from './python.js';
  2. // import * as safetensors from './safetensors.js';
// Namespace object that every class in this module is attached to.
const transformers = {};
  4. transformers.ModelFactory = class {
  5. async match(context) {
  6. const obj = await context.peek('json');
  7. if (obj) {
  8. if (obj.model_type && obj.architectures) {
  9. return context.set('transformers.config', obj);
  10. }
  11. if (obj.version && obj.added_tokens && obj.model) {
  12. return context.set('transformers.tokenizer', obj);
  13. }
  14. if (obj.tokenizer_class ||
  15. (obj.bos_token && obj.eos_token && obj.unk_token) ||
  16. (obj.pad_token && obj.additional_special_tokens) ||
  17. obj.special_tokens_map_file || obj.full_tokenizer_file) {
  18. return context.set('transformers.tokenizer.config', obj);
  19. }
  20. if (context.identifier === 'vocab.json' && Object.keys(obj).length > 256) {
  21. return context.set('transformers.vocab', obj);
  22. }
  23. }
  24. return null;
  25. }
  26. async open(context) {
  27. const fetch = async (name) => {
  28. try {
  29. const content = await context.fetch(name);
  30. await this.match(content);
  31. if (content.value) {
  32. return content;
  33. }
  34. } catch {
  35. // continue regardless of error
  36. }
  37. return null;
  38. };
  39. switch (context.type) {
  40. case 'transformers.config': {
  41. const tokenizer = await fetch('tokenizer.json');
  42. const tokenizer_config = await fetch('tokenizer_config.json');
  43. const vocab = await fetch('vocab.json');
  44. return new transformers.Model(context, tokenizer, tokenizer_config, vocab);
  45. }
  46. case 'transformers.tokenizer': {
  47. const config = await fetch('config.json');
  48. const tokenizer_config = await fetch('tokenizer_config.json');
  49. const vocab = await fetch('vocab.json');
  50. return new transformers.Model(config, context, tokenizer_config, vocab);
  51. }
  52. case 'transformers.tokenizer.config': {
  53. const config = await fetch('config.json');
  54. const tokenizer = await fetch('tokenizer.json');
  55. const vocab = await fetch('vocab.json');
  56. return new transformers.Model(config, tokenizer, context, vocab);
  57. }
  58. case 'transformers.vocab': {
  59. const config = await fetch('config.json');
  60. const tokenizer = await fetch('tokenizer.json');
  61. const tokenizer_config = await fetch('tokenizer_config.json');
  62. return new transformers.Model(config, tokenizer, tokenizer_config, context);
  63. }
  64. default: {
  65. throw new transformers.Error(`Unsupported Transformers format '${context.type}'.`);
  66. }
  67. }
  68. }
  69. filter(context, type) {
  70. return context.type !== 'transformers.config' || (type !== 'transformers.tokenizer' && type !== 'transformers.tokenizer.config' && type !== 'transformers.vocab' && type !== 'safetensors.json');
  71. }
  72. };
  73. transformers.Model = class {
  74. constructor(config, tokenizer, tokenizer_config, vocab) {
  75. this.format = 'Transformers';
  76. this.metadata = [];
  77. this.modules = [new transformers.Graph(config, tokenizer, tokenizer_config, vocab)];
  78. }
  79. };
  80. transformers.Graph = class {
  81. constructor(config, tokenizer, tokenizer_config, vocab) {
  82. this.type = 'graph';
  83. this.nodes = [];
  84. this.inputs = [];
  85. this.outputs = [];
  86. this.metadata = [];
  87. if (config) {
  88. for (const [key, value] of Object.entries(config.value)) {
  89. const argument = new transformers.Argument(key, value);
  90. this.metadata.push(argument);
  91. }
  92. }
  93. if (tokenizer || tokenizer_config) {
  94. const node = new transformers.Tokenizer(tokenizer, tokenizer_config, vocab);
  95. this.nodes.push(node);
  96. }
  97. }
  98. };
  99. transformers.Tokenizer = class {
  100. constructor(tokenizer, tokenizer_config) {
  101. this.type = { name: 'Tokenizer' };
  102. this.name = (tokenizer || tokenizer_config).identifier;
  103. this.attributes = [];
  104. if (tokenizer) {
  105. const obj = tokenizer.value;
  106. const keys = new Set(['decoder', 'model', 'post_processor', 'pre_tokenizer']);
  107. for (const [key, value] of Object.entries(tokenizer.value)) {
  108. if (!keys.has(key)) {
  109. const argument = new transformers.Argument(key, value);
  110. this.attributes.push(argument);
  111. }
  112. }
  113. for (const key of keys) {
  114. const value = obj[key];
  115. if (value) {
  116. const module = new transformers.Object(value);
  117. const argument = new transformers.Argument(key, module, 'object');
  118. this.attributes.push(argument);
  119. }
  120. }
  121. }
  122. }
  123. };
  124. transformers.Object = class {
  125. constructor(obj) {
  126. this.type = { name: obj.type };
  127. this.attributes = [];
  128. for (const [key, value] of Object.entries(obj)) {
  129. if (key !== 'type') {
  130. let argument = null;
  131. if (Array.isArray(value) && value.every((item) => typeof item === 'object')) {
  132. const values = value.map((item) => new transformers.Object(item));
  133. argument = new transformers.Argument(key, values, 'object[]');
  134. } else {
  135. argument = new transformers.Argument(key, value);
  136. }
  137. this.attributes.push(argument);
  138. }
  139. }
  140. }
  141. };
  142. transformers.Argument = class {
  143. constructor(name, value, type) {
  144. this.name = name;
  145. this.value = value;
  146. this.type = type || null;
  147. }
  148. };
// Module-specific error type so loader failures are distinguishable.
transformers.Error = class extends Error {
    constructor(message) {
        super(message);
        // NOTE: the name is a full sentence rather than an identifier-style
        // error name; kept as-is since consumers may display it verbatim.
        this.name = 'Error loading Transformers model.';
    }
};
  155. export const ModelFactory = transformers.ModelFactory;