executorch.js 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  // Experimental
  //
  // Namespace objects. The ExecuTorch container format lives in `executorch`;
  // each delegate backend gets its own namespace for its reader/graph classes.
  const executorch = {};
  const xnnpack = {};
  const vulkan = {};
  const coreml = {};
  6. import * as base from './base.js';
  7. import * as python from './python.js';
  8. import * as pytorch from './pytorch.js';
  executorch.ModelFactory = class {

      // Claims the context when executorch.Reader recognizes it. The reader is stored
      // on the context under 'executorch' and whatever context.set() returns is returned;
      // otherwise null.
      async match(context) {
          const reader = await executorch.Reader.open(context);
          return reader ? context.set('executorch', reader) : null;
      }

      // Loads the generated flatbuffer schema into executorch.schema, asks context.value
      // (presumably the reader stored by match() — confirm against the host) to parse the
      // program, and wraps the result in an executorch.Model.
      async open(context) {
          executorch.schema = await context.require('./executorch-schema');
          const target = context.value;
          await target.read();
          return new executorch.Model(target);
      }
  };
  executorch.Model = class {

      // Exposes one executorch.Graph for every chain of every execution plan in the
      // program parsed by `target` (an executorch.Reader after read()).
      constructor(target) {
          const program = target.program;
          this.format = `ExecuTorch v${program.version}`;
          this.graphs = [];
          for (const executionPlan of program.execution_plan) {
              for (const planChain of executionPlan.chains) {
                  this.graphs.push(new executorch.Graph(target, executionPlan, planChain));
              }
          }
      }
  };
  executorch.Graph = class {

      // One graph per (execution plan, chain) pair. Graph inputs/outputs come from
      // plan.inputs / plan.outputs, nodes from chain.instructions. EValue indices into
      // plan.values are resolved through a shared cache (`values`), which is also handed
      // to every executorch.Node so that an index always maps to the same argument data.
      constructor(target, plan, chain) {
          this.name = plan.name || '';
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          const values = new Map();
          // Wraps a list of flatbuffer Tensor tables (null entries stand for absent
          // optional tensors) into executorch.Value objects. A tensor whose
          // data_buffer_idx is > 0 gets an executorch.Tensor initializer.
          values.tensors = (index, items) => {
              const list = [];
              for (let i = 0; i < items.length; i++) {
                  const item = items[i];
                  const type = item ? new executorch.TensorType(item) : null;
                  const initializer = item && item.data_buffer_idx > 0 ? new executorch.Tensor(item, target) : null;
                  const identifier = items.length > 1 ? `${index}.${i}` : index.toString();
                  list.push(new executorch.Value(identifier, type, initializer));
              }
              return list;
          };
          // Resolves plan.values[index] to a `{ type, value }` record and caches it.
          // For tensors, `value` is a list of executorch.Value; for scalars it is the raw
          // scalar. When `output` is set and the value is not a tensor kind, an untyped
          // placeholder Value is created instead.
          values.map = (index, output) => {
              if (!values.has(index)) {
                  const executorch_flatbuffer = executorch.schema.executorch_flatbuffer;
                  const val = plan.values[index].val;
                  const tensor = val instanceof executorch_flatbuffer.Tensor ||
                      val instanceof executorch_flatbuffer.TensorList ||
                      val instanceof executorch_flatbuffer.OptionalTensorList;
                  if (output && !tensor) {
                      const value = [new executorch.Value(index.toString(), null, null)];
                      values.set(index, { type: null, value });
                  } else if (val instanceof executorch_flatbuffer.Null) {
                      values.set(index, { type: 'attribute', value: null });
                  } else if (val instanceof executorch_flatbuffer.Int) {
                      values.set(index, { type: 'int64', value: val.int_val });
                  } else if (val instanceof executorch_flatbuffer.Bool) {
                      // Booleans were previously labelled 'int64'.
                      values.set(index, { type: 'boolean', value: val.bool_val });
                  } else if (val instanceof executorch_flatbuffer.Double) {
                      values.set(index, { type: 'float64', value: val.double_val });
                  } else if (val instanceof executorch_flatbuffer.Tensor) {
                      values.set(index, { type: null, value: values.tensors(index, [val]) });
                  } else if (val instanceof executorch_flatbuffer.String) {
                      values.set(index, { type: 'string', value: val.string_val });
                  } else if (val instanceof executorch_flatbuffer.IntList) {
                      // Items are indices of Int EValues. Array.from avoids TypedArray.map,
                      // which would coerce each int_val back into the source element type.
                      const list = Array.from(val.items).map((item) => plan.values[item].val.int_val);
                      values.set(index, { type: 'int64[]', value: list });
                  } else if (val instanceof executorch_flatbuffer.DoubleList) {
                      throw new executorch.Error('executorch_flatbuffer.DoubleList not implemented.');
                  } else if (val instanceof executorch_flatbuffer.BoolList) {
                      throw new executorch.Error('executorch_flatbuffer.BoolList not implemented.');
                  } else if (val instanceof executorch_flatbuffer.TensorList || val instanceof executorch_flatbuffer.OptionalTensorList) {
                      // Items are indices of Tensor EValues; -1 marks an absent entry.
                      const items = Array.from(val.items).map((arg) => arg === -1 ? null : plan.values[arg].val);
                      values.set(index, { type: null, value: values.tensors(index, items) });
                  } else {
                      throw new executorch.Error(`Value type '${val.constructor.name}' not implemented.`);
                  }
              }
              return values.get(index);
          };
          for (let i = 0; i < plan.inputs.length; i++) {
              const value = values.map(plan.inputs[i]);
              const name = plan.inputs.length === 1 ? 'input' : `input.${i}`;
              this.inputs.push(new executorch.Argument(name, value.value, value.type));
          }
          for (let i = 0; i < plan.outputs.length; i++) {
              const value = values.map(plan.outputs[i]);
              const name = plan.outputs.length === 1 ? 'output' : `output.${i}`;
              this.outputs.push(new executorch.Argument(name, value.value, value.type));
          }
          for (const instruction of chain.instructions) {
              this.nodes.push(new executorch.Node(target, plan, chain, instruction, values));
          }
      }
  };
  executorch.Argument = class {

      // A named input/output/attribute slot. `type` falls back to null when falsy,
      // and the argument is visible unless `visible` is exactly false.
      constructor(name, value, type, visible) {
          this.name = name;
          this.value = value;
          this.type = type ? type : null;
          this.visible = !(visible === false);
      }
  };
  executorch.Value = class Value {

      // A value identified by a string name. An initializer's own type, when present,
      // takes precedence over the `type` argument.
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              throw new executorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          const initializerType = initializer && initializer.type ? initializer.type : null;
          this.type = initializerType || type || null;
          this.initializer = initializer || null;
      }
  };
  executorch.Node = class {

      // Builds a node from one chain instruction.
      // - KernelCall: looks up the operator schema through target.execution, then maps
      //   instr_args.args positionally onto schema.arguments followed by schema.returns.
      //   An argument whose alias set matches a return's alias is not shown as an input;
      //   its EValue index is used for that return instead.
      // - DelegateCall: uses the backend graph prepared by executorch.Reader.read() as
      //   the node type and splits args into that graph's inputs, then outputs.
      // Throws executorch.Error for a missing schema, an unprepared delegate, or an
      // unknown instruction type.
      constructor(target, plan, chain, instruction, values) {
          this.name = '';
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          const instr_args = instruction.instr_args;
          const executorch_flatbuffer = executorch.schema.executorch_flatbuffer;
          if (instr_args instanceof executorch_flatbuffer.KernelCall) {
              const op = plan.operators[instr_args.op_index];
              const name = op.name.split('::').pop();
              const identifier = op.overload ? `${op.name}.${op.overload}` : op.name;
              const schemas = target.execution.invoke('torch._C._jit_get_schemas_for_operator', [op.name]);
              const schema = schemas.find((schema) => schema.name === op.name && schema.overload_name === op.overload);
              if (!schema) {
                  throw new executorch.Error(`Operator schema for '${identifier}' not found.`);
              }
              const category = schema.category || '';
              // Single alias-set entry of an argument/return (e.g. 'a' in Tensor(a!)), or null.
              const alias = (arg) => arg && arg.alias_info && arg.alias_info.before_set.length === 1 ? arg.alias_info.before_set[0] : null;
              const returns = Array.isArray(schema.returns) ? schema.returns : [];
              const outputs = new Set(returns.map((arg) => alias(arg)).filter((alias) => alias !== null));
              const inputs = new Map();
              this.type = { name, identifier, category };
              const args = instr_args.args;
              let i = 0;
              for (; i < schema.arguments.length; i++) {
                  const arg = schema.arguments[i];
                  const index = args[i];
                  const output = alias(arg);
                  if (output && outputs.has(output)) {
                      // Aliased with a return: remember the index and surface it as an output.
                      inputs.set(output, index);
                      continue;
                  }
                  const value = values.map(index);
                  this.inputs.push(new executorch.Argument(arg ? arg.name : i.toString(), value.value, value.type));
              }
              for (const ret of returns) {
                  const output = alias(ret);
                  let index = args[i++];
                  index = output && inputs.has(output) ? inputs.get(output) : index;
                  const value = values.map(index, true);
                  this.outputs.push(new executorch.Argument(ret.name || '', value.value, value.type));
              }
          } else if (instr_args instanceof executorch_flatbuffer.DelegateCall) {
              const delegate = plan.delegates[instr_args.delegate_index];
              const args = instr_args.args;
              if (!delegate.backend || !delegate.backend.type) {
                  throw new executorch.Error(`ExecuTorch delegate '${delegate.id}' not implemented.`);
              }
              this.type = delegate.backend.type;
              const inputCount = this.type.inputs.length;
              const outputCount = this.type.outputs.length;
              const inputs = args.slice(0, inputCount);
              for (let i = 0; i < inputs.length; i++) {
                  const value = values.map(inputs[i]);
                  const name = inputs.length === 1 ? 'input' : `input.${i}`;
                  this.inputs.push(new executorch.Argument(name, value.value, value.type));
              }
              const outputs = args.slice(inputCount, inputCount + outputCount);
              for (let i = 0; i < outputs.length; i++) {
                  const value = values.map(outputs[i]);
                  // Previously keyed off inputs.length, mislabelling single outputs.
                  const name = outputs.length === 1 ? 'output' : `output.${i}`;
                  this.outputs.push(new executorch.Argument(name, value.value, value.type));
              }
              const decoder = new TextDecoder('utf-8');
              for (const spec of delegate.compile_specs) {
                  const value = spec.value instanceof Uint8Array ? decoder.decode(spec.value) : spec.value;
                  this.attributes.push(new executorch.Argument(spec.key, value, 'attribute'));
              }
          } else {
              throw new executorch.Error(`Instruction type '${instr_args.constructor.name}' not implemented.`);
          }
      }
  };
  executorch.TensorType = class {

      // Data type and shape of a flatbuffer Tensor table. The name table below is
      // indexed by tensor.scalar_type and cached on executorch.TensorType._types.
      constructor(tensor) {
          if (!executorch.TensorType._types) {
              executorch.TensorType._types = [
                  'uint8',
                  'int8', 'int16', 'int32', 'int64',
                  'float16', 'float32', 'float64',
                  'complex16', 'complex32', 'complex64',
                  'boolean',
                  'qint8', 'quint8', 'qint32',
                  'bfloat16',
                  'quint4x2', 'quint2x4', 'bits1x8', 'bits2x4', 'bits4x2', 'bits8', 'bits16',
                  'float8e5m2', 'float8e4m3fn', 'float8e5m2fnuz', 'float8e4m3fnuz',
                  'uint16', 'uint32', 'uint64'
              ];
          }
          const names = executorch.TensorType._types;
          const code = tensor.scalar_type;
          if (code >= names.length) {
              throw new executorch.Error(`Unknown tensor data type '${code}'.`);
          }
          this.dataType = names[code];
          this.shape = new executorch.TensorShape(Array.from(tensor.sizes));
      }

      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  executorch.TensorShape = class {

      // List of dimensions; renders as "[d0,d1,...]" or '' when there are none.
      constructor(dimensions) {
          this.dimensions = dimensions || [];
      }

      toString() {
          const dims = this.dimensions;
          return dims && dims.length > 0 ? `[${dims.map((dim) => dim.toString()).join(',')}]` : '';
      }
  };
  executorch.Tensor = class {

      // Constant tensor data. Only tensors without allocation_info are supported; their
      // bytes are located through program.constant_segment: the entry at data_buffer_idx
      // gives the start offset, and the next entry (or the segment size) gives the end.
      // The bytes are read through target.blob(), which can yield null.
      constructor(tensor, target) {
          this.type = new executorch.TensorType(tensor);
          const program = target.program;
          if (tensor.extra_tensor_info) {
              throw new executorch.Error('Extra tensor info not implemented.');
          }
          if (program.constant_buffers) {
              throw new executorch.Error('Constant buffers not implemented.');
          }
          if (tensor.allocation_info !== null) {
              throw new executorch.Error('Tensor allocation info not implemented.');
          }
          const slot = tensor.data_buffer_idx;
          const constants = program.constant_segment;
          const offsets = constants.offsets;
          const segment = program.segments[constants.segment_index];
          const begin = offsets[slot].toNumber();
          const end = slot + 1 < offsets.length ? offsets[slot + 1].toNumber() : segment.size.toNumber();
          this.values = target.blob(segment.offset.toNumber() + begin, end - begin);
          this.encoding = '<';
      }
  };
  executorch.Reader = class {

      // Returns a reader when the context is a flatbuffer with file identifier 'ET12',
      // otherwise null.
      static async open(context) {
          const reader = await context.peek('flatbuffers.binary');
          if (reader && reader.identifier === 'ET12') {
              return new executorch.Reader(context, reader);
          }
          return null;
      }

      constructor(context, reader) {
          this.context = context;
          this.reader = reader;
      }

      // Parses the Program table, detects the optional 'eh00' extended header (which
      // supplies segment_base_offset for blob()), and prepares every delegate backend
      // referenced by a DelegateCall instruction. Throws executorch.Error for unknown
      // delegate ids or data locations.
      async read() {
          const context = this.context;
          this.metadata = await pytorch.Metadata.open(context);
          this.execution = new python.Execution();
          this.metadata.register(this.execution);
          const executorch_flatbuffer = executorch.schema.executorch_flatbuffer;
          this.program = executorch_flatbuffer.Program.create(this.reader);
          // From here on this.reader is a raw binary reader over the whole file, used by blob().
          this.reader = await context.read('binary');
          if (this.reader.length >= 32) {
              this.reader.seek(8);
              const magic = String.fromCharCode(...this.reader.read(4));
              if (magic === 'eh00') {
                  this.extended_file_header = {
                      length: this.reader.uint32(),
                      program_size: this.reader.uint64().toNumber(),
                      segment_base_offset: this.reader.uint64().toNumber(),
                  };
              }
              this.reader.seek(0);
          }
          for (const plan of this.program.execution_plan) {
              for (const chain of plan.chains) {
                  for (const instruction of chain.instructions) {
                      const instr_args = instruction.instr_args;
                      if (instr_args instanceof executorch_flatbuffer.DelegateCall) {
                          const delegate = plan.delegates[instr_args.delegate_index];
                          if (delegate.backend) {
                              continue;
                          }
                          let data = null;
                          switch (delegate.processed.location) {
                              case executorch_flatbuffer.DataLocation.INLINE: {
                                  data = this.program.backend_delegate_data[delegate.processed.index].data;
                                  break;
                              }
                              case executorch_flatbuffer.DataLocation.SEGMENT: {
                                  const segment = this.program.segments[delegate.processed.index];
                                  data = this.blob(segment.offset.toNumber(), segment.size.toNumber());
                                  break;
                              }
                              default: {
                                  throw new executorch.Error(`Delegate data location '${delegate.processed.location}' not implemented.`);
                              }
                          }
                          switch (delegate.id) {
                              case 'XnnpackBackend': {
                                  delegate.backend = xnnpack.Reader.open(data, this);
                                  break;
                              }
                              case 'CoreMLBackend': {
                                  delegate.backend = coreml.Reader.open(data, this);
                                  break;
                              }
                              case 'VulkanBackend': {
                                  delegate.backend = vulkan.Reader.open(data, this);
                                  break;
                              }
                              default: {
                                  throw new executorch.Error(`ExecuTorch delegate '${delegate.id}' not implemented.`);
                              }
                          }
                          // Backend readers return null when the payload is not recognized;
                          // executorch.Node then reports the delegate as not implemented
                          // instead of a TypeError surfacing here.
                          if (delegate.backend) {
                              /* eslint-disable no-await-in-loop */
                              await delegate.backend.read();
                              /* eslint-enable no-await-in-loop */
                          }
                      }
                  }
              }
          }
      }

      // Reads `size` bytes at segment_base_offset + offset. Returns null when the file
      // has no extended header, since the segment base is then unknown.
      blob(offset, size) {
          if (this.extended_file_header) {
              this.reader.seek(this.extended_file_header.segment_base_offset + offset);
              const data = this.reader.read(size);
              this.reader.seek(0);
              return data;
          }
          return null;
      }
  };
  // Error type for ExecuTorch container parsing failures; `name` carries the
  // user-facing headline.
  executorch.Error = class extends Error {

      constructor(message) {
          super(message);
          this.name = 'Error loading ExecuTorch model.';
      }
  };
  xnnpack.Reader = class {

      // Returns a reader when `data` holds at least 30 bytes and carries the 'XH00'
      // magic after four leading bytes, otherwise null. `data` can be null: the
      // ExecuTorch reader's blob() returns null when the file has no extended header.
      static open(data, target) {
          if (data && data.length >= 30) {
              const reader = base.BinaryReader.open(data);
              reader.skip(4);
              const magic = String.fromCharCode(...reader.read(4));
              if (magic === 'XH00') {
                  return new xnnpack.Reader(reader, target);
              }
          }
          return null;
      }

      // Reads the rest of the header: after skipping 2 bytes, a uint32 offset/size pair
      // for the embedded flatbuffer and another for the constant data.
      constructor(reader, target) {
          this.reader = reader;
          this.target = target;
          reader.skip(2);
          this.flatbuffer = {
              offset: reader.uint32(),
              size: reader.uint32(),
          };
          this.constants = {
              offset: reader.uint32(),
              size: reader.uint32(),
          };
      }

      // Parses the XNNGraph flatbuffer and builds the backend graph exposed as `type`.
      async read() {
          this.reader.seek(this.flatbuffer.offset);
          const flatbuffers = await import('./flatbuffers.js');
          const data = this.reader.read(this.flatbuffer.size);
          const reader = flatbuffers.BinaryReader.open(data);
          if (!executorch.schema.fb_xnnpack.XNNGraph.identifier(reader)) {
              throw new xnnpack.Error('Invalid XNNPACK data.');
          }
          this.graph = executorch.schema.fb_xnnpack.XNNGraph.create(reader);
          this.reader.seek(0);
          const metadata = new xnnpack.Metadata();
          this.type = new xnnpack.Graph(metadata, this.graph, this);
      }

      // Returns the bytes of graph.constant_data[idx], relative to the constants offset.
      constant(idx) {
          const constant_data = this.graph.constant_data[idx];
          this.reader.seek(this.constants.offset + constant_data.offset.toNumber());
          const data = this.reader.read(constant_data.size.toNumber());
          this.reader.seek(0);
          return data;
      }
  };
  xnnpack.Graph = class {

      // Graph view of an XNNPACK delegate payload: graph inputs/outputs come from
      // input_ids / output_ids, and each entry of xnodes becomes an xnnpack.Node.
      constructor(metadata, graph, reader) {
          this.name = 'XnnpackBackend';
          this.type = 'graph';
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          // Creates the xnnpack.Value for entry `id` of graph.xvalues.
          const create = (id) => {
              const fb_xnnpack = executorch.schema.fb_xnnpack;
              const label = id.toString();
              const xvalue = graph.xvalues[id].xvalue_union;
              if (xvalue instanceof fb_xnnpack.XNNTensorValue) {
                  const type = new xnnpack.TensorType(xvalue);
                  // constant_buffer_idx 0 is treated as "no constant data".
                  const initializer = xvalue.constant_buffer_idx === 0 ? null : new xnnpack.Tensor(xvalue, reader);
                  return new xnnpack.Value(label, type, initializer);
              }
              if (xvalue instanceof fb_xnnpack.XNNQuantizedTensorValue) {
                  return new xnnpack.Value(label, null, null);
              }
              throw new xnnpack.Error(`Value type '${xvalue.constructor.name}' not implemented.`);
          };
          const values = new Map();
          values.map = (id) => {
              if (!values.has(id)) {
                  values.set(id, create(id));
              }
              return values.get(id);
          };
          // Builds graph-level arguments named `prefix` (single id) or `prefix.N`.
          const collect = (ids, prefix) => {
              const single = ids.length === 1;
              const list = [];
              for (let position = 0; position < ids.length; position++) {
                  const value = values.map(ids[position]);
                  list.push(new xnnpack.Argument(single ? prefix : `${prefix}.${position}`, [value]));
              }
              return list;
          };
          this.inputs.push(...collect(graph.input_ids, 'input'));
          this.outputs.push(...collect(graph.output_ids, 'output'));
          for (const xnode of graph.xnodes) {
              this.nodes.push(new xnnpack.Node(metadata, xnode, values));
          }
      }
  };
  xnnpack.Node = class {

      // Turns every own field of the node union into an argument. Fields ending in
      // '_id' are value references (-1 means none); other fields are attributes, with
      // typed arrays converted to plain arrays. 'output_id' becomes the node's output.
      constructor(metadata, xnode, values) {
          const node = xnode.xnode_union;
          const className = node.constructor.name;
          this.type = metadata.type(className) || { name: className };
          this.name = '';
          this.inputs = [];
          this.outputs = [];
          for (const [field, content] of Object.entries(node)) {
              let value = null;
              let type = null;
              if (field.endsWith('_id')) {
                  value = content === -1 ? [] : [values.map(content)];
              } else {
                  value = ArrayBuffer.isView(content) ? Array.from(content) : content;
                  type = 'attribute';
              }
              const target = field === 'output_id' ? this.outputs : this.inputs;
              target.push(new xnnpack.Argument(field, value, type));
          }
      }
  };
  xnnpack.Argument = class {

      // A named slot; `type` falls back to null when falsy, and the argument is
      // visible unless `visible` is exactly false.
      constructor(name, value, type, visible) {
          this.name = name;
          this.value = value;
          this.type = type ? type : null;
          this.visible = !(visible === false);
      }
  };
  xnnpack.Value = class Value {

      // A value identified by a string name. An initializer's own type, when present,
      // takes precedence over `type`. Throws xnnpack.Error (consistent with the rest of
      // the xnnpack namespace) when the name is not a string.
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              throw new xnnpack.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          this.type = initializer && initializer.type ? initializer.type : type || null;
          this.initializer = initializer || null;
      }
  };
  xnnpack.Metadata = class {

      // Display categories keyed by XNNPACK node class name. Each name is listed once;
      // the previous version registered 'XNNGlobalAvgPooling2d' twice.
      constructor() {
          this._types = new Map();
          const categories = [
              ['_XNNCat', 'Tensor'],
              ['_XNNNodeConv', 'Layer'],
              ['XNNArgMaxPooling2d', 'Pool'],
              ['XNNAvgPooling2d', 'Pool'],
              ['XNNCeiling', 'Activation'],
              ['XNNConcatenate2', 'Tensor'],
              ['XNNConcatenate3', 'Tensor'],
              ['XNNConcatenate4', 'Tensor'],
              ['XNNConcatenate5', 'Tensor'],
              ['XNNConv2d', 'Layer'],
              ['XNNConvTranspose2d', 'Layer'],
              ['XNNDepthwiseConv2d', 'Layer'],
              ['XNNELU', 'Activation'],
              ['XNNFullyConnected', 'Layer'],
              ['XNNGelu', 'Activation'],
              ['XNNGlobalAvgPooling2d', 'Pool'],
              ['XNNHardswish', 'Activation'],
              ['XNNLeakyReLU', 'Activation'],
              ['XNNMaxPooling2d', 'Pool'],
              ['XNNPReLU', 'Activation'],
              ['XNNSigmoid', 'Activation'],
              ['XNNSoftmax', 'Activation'],
              ['XNNStaticTranspose', 'Transform'],
          ];
          for (const [name, category] of categories) {
              this.register(name, category);
          }
      }

      register(name, category) {
          this._types.set(name, { name, category });
      }

      // Returns the registered { name, category } record, or undefined.
      type(name) {
          return this._types.get(name);
      }
  };
  xnnpack.TensorType = class {

      // Data type and shape of an XNNTensorValue. The name table is indexed by
      // tensor.datatype and cached on xnnpack.TensorType._types. The cache previously
      // read executorch.TensorType._types, which made XNNPACK datatypes resolve through
      // the ExecuTorch scalar-type table whenever that had been initialized first.
      constructor(tensor) {
          xnnpack.TensorType._types = xnnpack.TensorType._types || [
              'invalid', 'float32', 'float16',
              'qint8', 'quint8', 'qint32',
              'qcint8', 'qcint32', 'qcint4',
              'qdint8', 'qbint4'
          ];
          if (tensor.datatype >= xnnpack.TensorType._types.length) {
              throw new xnnpack.Error(`Unknown tensor data type '${tensor.datatype}'.`);
          }
          this.dataType = xnnpack.TensorType._types[tensor.datatype];
          this.shape = new xnnpack.TensorShape(Array.from(tensor.dims));
      }

      toString() {
          return this.dataType + this.shape.toString();
      }
  };
  xnnpack.TensorShape = class {

      // List of dimensions; renders as "[d0,d1,...]" or '' when there are none.
      constructor(dimensions) {
          this.dimensions = dimensions || [];
      }

      toString() {
          const dims = this.dimensions;
          return dims && dims.length > 0 ? `[${dims.map((dim) => dim.toString()).join(',')}]` : '';
      }
  };
  xnnpack.Tensor = class {

      // Constant tensor whose bytes are fetched from the delegate reader by
      // constant_buffer_idx.
      constructor(tensor, reader) {
          const index = tensor.constant_buffer_idx;
          this.type = new xnnpack.TensorType(tensor);
          this.values = reader.constant(index);
          this.encoding = '<';
      }
  };
  // Error type for XNNPACK delegate parsing failures; `name` carries the
  // user-facing headline.
  xnnpack.Error = class extends Error {

      constructor(message) {
          super(message);
          this.name = 'Error loading XNNPACK model.';
      }
  };
  vulkan.Reader = class {

      // Returns a reader when `data` holds at least 30 bytes and carries the 'VH00'
      // magic after four leading bytes, otherwise null. `data` can be null: the
      // ExecuTorch reader's blob() returns null when the file has no extended header.
      static open(data, target) {
          if (data && data.length >= 30) {
              const reader = base.BinaryReader.open(data);
              reader.skip(4);
              const magic = String.fromCharCode(...reader.read(4));
              if (magic === 'VH00') {
                  return new vulkan.Reader(reader, target);
              }
          }
          return null;
      }

      // Reads the rest of the header: after skipping 2 bytes, a uint32 offset/size pair
      // for the embedded flatbuffer and another for the constant data.
      constructor(reader, target) {
          this.reader = reader;
          this.target = target;
          reader.skip(2);
          this.flatbuffer = {
              offset: reader.uint32(),
              size: reader.uint32(),
          };
          this.constants = {
              offset: reader.uint32(),
              size: reader.uint32(),
          };
      }

      // Registers the custom operator schema used by Vulkan graphs, parses the VkGraph
      // flatbuffer, and builds the backend graph exposed as `type`.
      async read() {
          this.reader.seek(this.flatbuffer.offset);
          const metadata = new vulkan.Metadata(this.target.execution);
          // The signature previously ended with a stray ')' after the return type.
          metadata.register('conv_with_clamp(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor');
          const flatbuffers = await import('./flatbuffers.js');
          const data = this.reader.read(this.flatbuffer.size);
          const reader = flatbuffers.BinaryReader.open(data);
          if (!executorch.schema.vkgraph.VkGraph.identifier(reader)) {
              throw new vulkan.Error('Invalid Vulkan data.');
          }
          this.graph = executorch.schema.vkgraph.VkGraph.create(reader);
          this.reader.seek(0);
          this.type = new vulkan.Graph(metadata, this.graph, this);
      }

      // Returns the bytes of graph.constants[id], relative to the constants offset.
      constant(id) {
          const constant = this.graph.constants[id];
          this.reader.seek(this.constants.offset + constant.offset.toNumber());
          const data = this.reader.read(constant.length.toNumber());
          this.reader.seek(0);
          return data;
      }
  };
  vulkan.Graph = class {

      // Graph view of a Vulkan delegate payload: graph inputs/outputs come from
      // input_ids / output_ids, and each op in graph.chain becomes a vulkan.Node.
      // Values are cached as { type, value } records shared with the nodes.
      constructor(metadata, graph, reader) {
          this.name = 'VulkanBackend';
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          // Builds the { type, value } record for entry `id` of graph.values.
          const resolve = (id) => {
              const vkgraph = executorch.schema.vkgraph;
              const arg = graph.values[id].value;
              if (arg instanceof vkgraph.VkTensor) {
                  const type = new vulkan.TensorType(arg);
                  // constant_id -1 means the tensor has no constant data.
                  const initializer = arg.constant_id === -1 ? null : new vulkan.Tensor(arg, reader);
                  return { type: null, value: [new vulkan.Value(id.toString(), type, initializer)] };
              }
              if (arg instanceof vkgraph.Int) {
                  return { type: 'int64', value: arg.int_val };
              }
              if (arg instanceof vkgraph.IntList) {
                  return { type: 'int64[]', value: Array.from(arg.items) };
              }
              if (arg instanceof vkgraph.Double) {
                  return { type: 'float64', value: arg.double_val };
              }
              if (arg instanceof vkgraph.Bool) {
                  return { type: 'boolean', value: arg.bool_val };
              }
              if (arg instanceof vkgraph.Null) {
                  return { type: 'attribute', value: null };
              }
              throw new Error(`Value type '${arg.constructor.name}' not implemented.`);
          };
          const values = new Map();
          values.map = (id) => {
              if (!values.has(id)) {
                  values.set(id, resolve(id));
              }
              return values.get(id);
          };
          // Builds graph-level arguments named `prefix` (single id) or `prefix.N`.
          const collect = (ids, prefix) => {
              const single = ids.length === 1;
              const list = [];
              for (let position = 0; position < ids.length; position++) {
                  const record = values.map(ids[position]);
                  const label = single ? prefix : `${prefix}.${position}`;
                  list.push(new vulkan.Argument(label, record.value, record.type));
              }
              return list;
          };
          this.inputs.push(...collect(graph.input_ids, 'input'));
          this.outputs.push(...collect(graph.output_ids, 'output'));
          for (const op of graph.chain) {
              this.nodes.push(new vulkan.Node(metadata, op, values));
          }
      }
  };
  vulkan.Node = class {

      // Builds a node from a Vulkan op. op.args is mapped positionally onto
      // schema.arguments, then schema.returns. Args beyond the declared arguments and
      // returns are kept as unnamed outputs rather than raising a TypeError on an
      // undefined schema entry. Throws vulkan.Error when no schema matches op.name.
      constructor(metadata, op, values) {
          const schema = metadata.type(op.name);
          if (!schema) {
              throw new vulkan.Error(`Operator schema for '${op.name}' not found.`);
          }
          this.type = {
              // Strip the trailing '.overload' segment for the display name.
              name: op.name.split(/\.([^.]*)$/)[0],
              identifier: op.name,
              category: schema.category || ''
          };
          this.name = op.node_id.toString();
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          const argumentCount = schema.arguments.length;
          const returns = schema.returns || [];
          for (let i = 0; i < op.args.length; i++) {
              const input = i < argumentCount;
              const def = input ? schema.arguments[i] : returns[i - argumentCount];
              const value = values.map(op.args[i]);
              const argument = new vulkan.Argument(def && def.name ? def.name : '', value.value, value.type);
              if (input) {
                  this.inputs.push(argument);
              } else {
                  this.outputs.push(argument);
              }
          }
      }
  };
  vulkan.Argument = class {

      // A named slot; `type` falls back to null when falsy, and the argument is
      // visible unless `visible` is exactly false.
      constructor(name, value, type, visible) {
          this.name = name;
          this.value = value;
          this.type = type ? type : null;
          this.visible = !(visible === false);
      }
  };
  vulkan.Value = class Value {

      // A value identified by a string name. An initializer's own type, when present,
      // takes precedence over the `type` argument.
      constructor(name, type, initializer) {
          if (typeof name !== 'string') {
              throw new executorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          const initializerType = initializer && initializer.type ? initializer.type : null;
          this.type = initializerType || type || null;
          this.initializer = initializer || null;
      }
  };
  727. vulkan.TensorType = class {
  728. constructor(tensor) {
  729. const types = ['bool', 'uint8', 'int8', 'int32', 'float16', 'float32'];
  730. if (tensor.datatype >= types.length) {
  731. throw new vulkan.Error(`Unknown tensor data type '${tensor.datatype}'.`);
  732. }
  733. this.dataType = types[tensor.datatype];
  734. this.shape = new vulkan.TensorShape(Array.from(tensor.dims));
  735. }
  736. toString() {
  737. return this.dataType + this.shape.toString();
  738. }
  739. };
  740. vulkan.TensorShape = class {
  741. constructor(dimensions) {
  742. this.dimensions = dimensions || [];
  743. }
  744. toString() {
  745. if (this.dimensions && this.dimensions.length > 0) {
  746. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  747. }
  748. return '';
  749. }
  750. };
  751. vulkan.Tensor = class {
  752. constructor(tensor, reader) {
  753. this.type = new vulkan.TensorType(tensor);
  754. this.values = reader.constant(tensor.constant_id);
  755. this.encoding = '<';
  756. }
  757. };
  758. vulkan.Metadata = class {
  759. constructor(execution) {
  760. this.execution = execution;
  761. }
  762. register(signature) {
  763. const torch = this.execution.register('torch');
  764. const registry = torch._C.getRegistry();
  765. const schema = torch.FunctionSchema.parse(signature);
  766. const op = new torch._C.Operator(schema);
  767. registry.registerOperator(op);
  768. }
  769. type(identifier) {
  770. identifier = identifier.split(/\.([^.]*)$/);
  771. const name = identifier[0].replace('.', '::');
  772. const overload = identifier[1] === 'default' ? '' : identifier[1];
  773. const schemas = this.execution.invoke('torch._C._jit_get_schemas_for_operator', [name]);
  774. const schema = schemas.find((schema) => schema.name === name && schema.overload_name === overload);
  775. return schema;
  776. }
  777. };
  778. vulkan.Error = class extends Error {
  779. constructor(message) {
  780. super(message);
  781. this.name = 'Error loading Vulkan model.';
  782. }
  783. };
  784. coreml.Reader = class {
  785. static open(data, target) {
  786. const reader = base.BinaryReader.open(data);
  787. return new coreml.Reader(reader, target);
  788. }
  789. constructor(reader, target) {
  790. this.reader = reader;
  791. this.target = target;
  792. }
  793. async factory() {
  794. const coreml = await import('./coreml.js');
  795. return new coreml.ModelFactory();
  796. }
  797. async read() {
  798. const entries = this.entries(this.reader);
  799. const factory = await this.factory();
  800. const protobuf = await import('./protobuf.js');
  801. for (const [key, value] of entries) {
  802. const path = key.split('/');
  803. const identifier = path.pop();
  804. const folder = path.length === 0 ? '' : `${path.join('/')}/`;
  805. const locals = new Map(Array.from(entries).filter(([key]) => key.startsWith(folder)).map(([key, value]) => [key.substring(folder.length), value]));
  806. const context = new coreml.Context(this, identifier, value, locals, protobuf);
  807. /* eslint-disable no-await-in-loop */
  808. const type = await factory.match(context);
  809. /* eslint-enable no-await-in-loop */
  810. if (type === 'coreml.manifest') {
  811. /* eslint-disable no-await-in-loop */
  812. const model = await factory.open(context);
  813. /* eslint-enable no-await-in-loop */
  814. [this.type] = model.graphs;
  815. this.type.name = 'CoreMLBackend';
  816. return;
  817. }
  818. }
  819. }
  820. stream(offset, size) {
  821. this.reader.seek(offset);
  822. const stream = this.reader.stream(size);
  823. this.reader.seek(0);
  824. return stream;
  825. }
  826. entries(reader) {
  827. const files = new Map();
  828. reader.seek(reader.length - 1);
  829. const str = [];
  830. let depth = 0;
  831. do {
  832. const c = String.fromCharCode(reader.byte());
  833. reader.skip(-2);
  834. if (c === '{') {
  835. depth++;
  836. } else if (c === '}') {
  837. depth--;
  838. }
  839. str.push(c);
  840. } while (depth > 0);
  841. const metadata = JSON.parse(str.join(''));
  842. const nodes = metadata.nodes;
  843. const roots = Array.from(nodes);
  844. for (const root of roots) {
  845. if (root !== null) {
  846. for (const index of Object.values(root.children)) {
  847. roots[index] = null;
  848. }
  849. }
  850. }
  851. const process = (path, node) => {
  852. path = path ? `${path}/${node.name}` : node.name;
  853. if (node.kind === 0) {
  854. files.set(path, node.dataRegion);
  855. } else if (node.kind === 1) {
  856. for (const index of Object.values(node.children)) {
  857. process(path, nodes[index]);
  858. }
  859. } else {
  860. throw new Error(`Node kind '${node.kind}' not implemented.`);
  861. }
  862. };
  863. for (const root of roots.filter((node) => node !== null)) {
  864. process('', root);
  865. }
  866. return files;
  867. }
  868. };
  869. coreml.Context = class {
  870. constructor(reader, identifier, location, entries, protobuf) {
  871. this._reader = reader;
  872. this._location = location;
  873. this._identifier = identifier;
  874. this._entries = entries;
  875. this._protobuf = protobuf;
  876. }
  877. get identifier() {
  878. return this._identifier;
  879. }
  880. get stream() {
  881. if (!this._stream) {
  882. this._stream = this._reader.stream(this._location.offset, this._location.size);
  883. }
  884. return this._stream;
  885. }
  886. async tags(type) {
  887. if (type === 'pb' && this.identifier.endsWith('.mlmodel')) {
  888. return new Map([[1,0],[2,2]]);
  889. }
  890. return new Map();
  891. }
  892. async peek(type) {
  893. if (type === 'json') {
  894. const data = this.stream.peek();
  895. const decoder = new TextDecoder('utf-8');
  896. const text = decoder.decode(data);
  897. return JSON.parse(text);
  898. }
  899. return null;
  900. }
  901. async read(type) {
  902. if (type === 'protobuf.binary') {
  903. return this._protobuf.BinaryReader.open(this.stream);
  904. }
  905. return null;
  906. }
  907. async fetch(file) {
  908. if (this._entries.has(file)) {
  909. const location = this._entries.get(file);
  910. const identifier = file.split('/').pop();
  911. return new coreml.Context(this._reader, identifier, location, this._entries, this._protobuf);
  912. }
  913. return null;
  914. }
  915. async require(id) {
  916. return this._reader.target.context.require(id);
  917. }
  918. async metadata(name) {
  919. return this._reader.target.context.metadata(name);
  920. }
  921. set(type, value) {
  922. this.type = type;
  923. this.value = value;
  924. return type;
  925. }
  926. };
// Module export: expose executorch.ModelFactory as this module's ModelFactory.
export const ModelFactory = executorch.ModelFactory;