worker.js 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. import * as base from '../source/base.js';
  2. import * as fs from 'fs/promises';
  3. import * as mock from './mock.js';
  4. import * as node from '../source/node.js';
  5. import * as path from 'path';
  6. import * as process from 'process';
  7. import * as python from '../source/python.js';
  8. import * as tar from '../source/tar.js';
  9. import * as url from 'url';
  10. import * as view from '../source/view.js';
  11. import * as worker_threads from 'worker_threads';
  12. import * as zip from '../source/zip.js';
  13. const access = async (path) => {
  14. try {
  15. await fs.access(path);
  16. return true;
  17. } catch {
  18. return false;
  19. }
  20. };
  21. const dirname = (...args) => {
  22. const file = url.fileURLToPath(import.meta.url);
  23. const dir = path.dirname(file);
  24. return path.join(dir, ...args);
  25. };
  26. const decompress = (buffer) => {
  27. let archive = zip.Archive.open(buffer, 'gzip');
  28. if (archive && archive.entries.size === 1) {
  29. const stream = archive.entries.values().next().value;
  30. buffer = stream.peek();
  31. }
  32. const formats = [zip, tar];
  33. for (const module of formats) {
  34. archive = module.Archive.open(buffer);
  35. if (archive) {
  36. break;
  37. }
  38. }
  39. return archive;
  40. };
  41. export class Target {
  42. constructor(item) {
  43. Object.assign(this, item);
  44. this.events = {};
  45. this.tags = new Set(this.tags);
  46. this.folder = item.type ? path.normalize(dirname('..', 'third_party' , 'test', item.type)) : process.cwd();
  47. this.assert = !this.assert || Array.isArray(this.assert) ? this.assert : [this.assert];
  48. this.serial = false;
  49. }
  50. on(event, callback) {
  51. this.events[event] = this.events[event] || [];
  52. this.events[event].push(callback);
  53. }
  54. emit(event, data) {
  55. if (this.events && this.events[event]) {
  56. for (const callback of this.events[event]) {
  57. callback(this, data);
  58. }
  59. }
  60. }
  61. status(message) {
  62. this.emit('status', message);
  63. }
  64. async execute() {
  65. if (this.measures) {
  66. this.measures.set('name', this.name);
  67. }
  68. await zip.Archive.import();
  69. const environment = { zoom: 'none', serial: this.serial };
  70. this.host = await new mock.Host(environment);
  71. this.view = new view.View(this.host);
  72. this.view.options.attributes = true;
  73. this.view.options.initializers = true;
  74. const time = async (method) => {
  75. const start = process.hrtime.bigint();
  76. let err = null;
  77. try {
  78. await method.call(this);
  79. } catch (error) {
  80. err = error;
  81. }
  82. const duration = Number(process.hrtime.bigint() - start) / 1e9;
  83. if (this.measures) {
  84. this.measures.set(method.name, duration);
  85. }
  86. if (err) {
  87. throw err;
  88. }
  89. };
  90. this.status({ name: 'name', target: this.name });
  91. const errors = [];
  92. try {
  93. await time(this.download);
  94. await time(this.load);
  95. await time(this.validate);
  96. if (!this.tags.has('skip-render')) {
  97. await time(this.render);
  98. }
  99. } catch (error) {
  100. errors.push(error);
  101. }
  102. errors.push(...this.host.errors);
  103. if (errors.length === 0 && this.error) {
  104. throw new Error('Expected error.');
  105. }
  106. if (errors.length > 0 && (!this.error || errors.map((error) => error.message).join('\n') !== this.error)) {
  107. throw errors[0];
  108. }
  109. this.view.dispose();
  110. }
  111. async request(url, init) {
  112. const response = await fetch(url, init);
  113. if (!response.ok) {
  114. throw new Error(response.status.toString());
  115. }
  116. if (response.body) {
  117. const reader = response.body.getReader();
  118. const length = response.headers.has('Content-Length') ? parseInt(response.headers.get('Content-Length'), 10) : -1;
  119. let position = 0;
  120. /* eslint-disable consistent-this */
  121. const target = this;
  122. /* eslint-enable consistent-this */
  123. const stream = new ReadableStream({
  124. async start(controller) {
  125. const read = async () => {
  126. try {
  127. const result = await reader.read();
  128. if (result.done) {
  129. target.status({ name: 'download' });
  130. controller.close();
  131. } else {
  132. position += result.value.length;
  133. if (length >= 0) {
  134. const percent = position / length;
  135. target.status({ name: 'download', target: url, percent });
  136. } else {
  137. target.status({ name: 'download', target: url, position });
  138. }
  139. controller.enqueue(result.value);
  140. return await read();
  141. }
  142. } catch (error) {
  143. controller.error(error);
  144. throw error;
  145. }
  146. return null;
  147. };
  148. return read();
  149. }
  150. });
  151. return new Response(stream, {
  152. status: response.status,
  153. statusText: response.statusText,
  154. headers: response.headers
  155. });
  156. }
  157. return response;
  158. }
  159. async download(targets, sources) {
  160. targets = targets || Array.from(this.targets);
  161. sources = sources || this.source;
  162. const files = targets.map((file) => path.resolve(this.folder, file));
  163. const exists = await Promise.all(files.map((file) => access(file)));
  164. if (exists.every((value) => value)) {
  165. return;
  166. }
  167. if (!sources) {
  168. throw new Error('Download source not specified.');
  169. }
  170. let source = '';
  171. let sourceFiles = [];
  172. const match = sources.match(/^(.*?)\[(.*?)\](.*)$/);
  173. if (match) {
  174. [, source, sourceFiles, sources] = match;
  175. sourceFiles = sourceFiles.split(',').map((file) => file.trim());
  176. sources = sources && sources.startsWith(',') ? sources.substring(1).trim() : '';
  177. } else {
  178. const comma = sources.indexOf(',');
  179. if (comma === -1) {
  180. source = sources;
  181. sources = '';
  182. } else {
  183. source = sources.substring(0, comma);
  184. sources = sources.substring(comma + 1);
  185. }
  186. }
  187. await Promise.all(targets.map((target) => {
  188. const dir = path.dirname(`${this.folder}/${target}`);
  189. return fs.mkdir(dir, { recursive: true });
  190. }));
  191. const response = await this.request(source);
  192. const buffer = await response.arrayBuffer();
  193. const data = new Uint8Array(buffer);
  194. if (sourceFiles.length > 0) {
  195. this.status({ name: 'decompress' });
  196. const archive = decompress(data);
  197. for (const name of sourceFiles) {
  198. this.status({ name: 'write', target: name });
  199. if (name === '.') {
  200. const target = targets.shift();
  201. const dir = path.join(this.folder, target);
  202. /* eslint-disable no-await-in-loop */
  203. await fs.mkdir(dir, { recursive: true });
  204. /* eslint-enable no-await-in-loop */
  205. } else {
  206. const stream = archive.entries.get(name);
  207. if (!stream) {
  208. throw new Error(`Entry not found '${name}. Archive contains entries: ${JSON.stringify(Array.from(archive.entries.keys()))} .`);
  209. }
  210. const target = targets.shift();
  211. const buffer = stream.peek();
  212. const file = path.join(this.folder, target);
  213. /* eslint-disable no-await-in-loop */
  214. await fs.writeFile(file, buffer, null);
  215. /* eslint-enable no-await-in-loop */
  216. }
  217. }
  218. } else {
  219. const target = targets.shift();
  220. this.status({ name: 'write', target });
  221. await fs.writeFile(`${this.folder}/${target}`, data, null);
  222. }
  223. if (targets.length > 0 && sources.length > 0) {
  224. await this.download(targets, sources);
  225. }
  226. }
  227. async load() {
  228. const target = path.resolve(this.folder, this.targets[0]);
  229. const identifier = path.basename(target);
  230. const stat = await fs.stat(target);
  231. let context = null;
  232. if (stat.isFile()) {
  233. const stream = new node.FileStream(target, 0, stat.size, stat.mtimeMs);
  234. const dirname = path.dirname(target);
  235. context = new mock.Context(this.host, dirname, identifier, stream, new Map());
  236. } else if (stat.isDirectory()) {
  237. const entries = new Map();
  238. const file = async (pathname) => {
  239. const stat = await fs.stat(pathname);
  240. const stream = new node.FileStream(pathname, 0, stat.size, stat.mtimeMs);
  241. const name = pathname.split(path.sep).join(path.posix.sep);
  242. entries.set(name, stream);
  243. };
  244. const walk = async (dir) => {
  245. const stats = await fs.readdir(dir, { withFileTypes: true });
  246. const promises = [];
  247. for (const stat of stats) {
  248. const pathname = path.join(dir, stat.name);
  249. if (stat.isDirectory()) {
  250. promises.push(walk(pathname));
  251. } else if (stat.isFile()) {
  252. promises.push(file(pathname));
  253. }
  254. }
  255. await Promise.all(promises);
  256. };
  257. await walk(target);
  258. context = new mock.Context(this.host, target, identifier, null, entries);
  259. }
  260. const modelFactoryService = new view.ModelFactoryService(this.host);
  261. this.model = await modelFactoryService.open(context);
  262. this.view.model = this.model;
  263. }
  264. async validate() {
  265. const model = this.model;
  266. if (!model.format || (this.format && this.format !== model.format)) {
  267. throw new Error(`Invalid model format '${model.format}'.`);
  268. }
  269. if (this.producer && model.producer !== this.producer) {
  270. throw new Error(`Invalid producer '${model.producer}'.`);
  271. }
  272. if (this.runtime && model.runtime !== this.runtime) {
  273. throw new Error(`Invalid runtime '${model.runtime}'.`);
  274. }
  275. if (model.metadata && (!Array.isArray(model.metadata) || !model.metadata.every((argument) => argument.name && (argument.value || argument.value === null || argument.value === '' || argument.value === false || argument.value === 0)))) {
  276. throw new Error("Invalid model metadata.");
  277. }
  278. if (this.assert) {
  279. for (const assert of this.assert) {
  280. const parts = assert.split('==').map((item) => item.trim());
  281. const properties = parts[0].split('.');
  282. const value = JSON.parse(parts[1].replace(/\s*'|'\s*/g, '"'));
  283. let context = { model };
  284. while (properties.length) {
  285. const property = properties.shift();
  286. if (context[property] !== undefined) {
  287. context = context[property];
  288. continue;
  289. }
  290. const match = /(.*)\[(.*)\]/.exec(property);
  291. if (match && match.length === 3 && context[match[1]] !== undefined) {
  292. const array = context[match[1]];
  293. const index = parseInt(match[2], 10);
  294. if (array[index] !== undefined) {
  295. context = array[index];
  296. continue;
  297. }
  298. }
  299. throw new Error(`Invalid property path '${parts[0]}'.`);
  300. }
  301. if (context !== value) {
  302. throw new Error(`Invalid '${context}' != '${assert}'.`);
  303. }
  304. }
  305. }
  306. if (model.version || model.description || model.author || model.license) {
  307. // continue
  308. }
  309. const validateGraph = async (graph) => {
  310. /* eslint-disable no-unused-expressions */
  311. const values = new Map();
  312. const validateValue = async (value) => {
  313. if (value === null) {
  314. return;
  315. }
  316. value.name.toString();
  317. value.name.length;
  318. value.description;
  319. if (value.quantization) {
  320. if (!this.tags.has('quantization')) {
  321. throw new Error("Invalid 'quantization' tag.");
  322. }
  323. const quantization = new view.Quantization(value.quantization);
  324. quantization.toString();
  325. }
  326. if (value.type) {
  327. value.type.toString();
  328. }
  329. if (value.initializer) {
  330. value.initializer.type.toString();
  331. if (value.initializer && value.initializer.peek && !value.initializer.peek()) {
  332. await value.initializer.read();
  333. }
  334. const tensor = new base.Tensor(value.initializer);
  335. if (!this.tags.has('skip-tensor-value')) {
  336. if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
  337. throw new Error(`Tensor encoding '${tensor.encoding}' is not implemented.`);
  338. }
  339. if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
  340. throw new Error(`Tensor layout '${tensor.layout}' is not implemented.`);
  341. }
  342. if (!tensor.empty) {
  343. if (tensor.type && tensor.type.dataType === '?') {
  344. throw new Error('Tensor data type is not defined.');
  345. } else if (tensor.type && !tensor.type.shape) {
  346. throw new Error('Tensor shape is not defined.');
  347. } else {
  348. tensor.toString();
  349. if (this.tags.has('validation')) {
  350. const size = tensor.type.shape.dimensions.reduce((a, b) => a * b, 1);
  351. if (tensor.type && tensor.type.dataType !== '?' && size < 8192) {
  352. let data_type = '?';
  353. switch (tensor.type.dataType) {
  354. case 'boolean': data_type = 'bool'; break;
  355. case 'bfloat16': data_type = 'float32'; break;
  356. case 'float8e5m2': data_type = 'float16'; break;
  357. case 'float8e5m2fnuz': data_type = 'float16'; break;
  358. case 'float8e4m3fn': data_type = 'float16'; break;
  359. case 'float8e4m3fnuz': data_type = 'float16'; break;
  360. case 'int4': data_type = 'int8'; break;
  361. default: data_type = tensor.type.dataType; break;
  362. }
  363. Target.execution = Target.execution || new python.Execution();
  364. const execution = Target.execution;
  365. const io = execution.__import__('io');
  366. const numpy = execution.__import__('numpy');
  367. const bytes = new io.BytesIO();
  368. const dtype = new numpy.dtype(data_type);
  369. const array = numpy.asarray(tensor.value, dtype);
  370. numpy.save(bytes, array);
  371. }
  372. }
  373. }
  374. }
  375. }
  376. } else if (value.name.length === 0) {
  377. throw new Error('Empty value name.');
  378. }
  379. if (value.name.length > 0 && value.initializer === null) {
  380. if (!values.has(value.name)) {
  381. values.set(value.name, value);
  382. } else if (value !== values.get(value.name)) {
  383. throw new Error(`Duplicate value '${value.name}'.`);
  384. }
  385. }
  386. };
  387. const signatures = Array.isArray(graph.signatures) ? graph.signatures : [graph];
  388. for (const signature of signatures) {
  389. for (const input of signature.inputs) {
  390. input.name.toString();
  391. input.name.length;
  392. for (const value of input.value) {
  393. /* eslint-disable no-await-in-loop */
  394. await validateValue(value);
  395. /* eslint-enable no-await-in-loop */
  396. }
  397. }
  398. for (const output of signature.outputs) {
  399. output.name.toString();
  400. output.name.length;
  401. if (Array.isArray(output.value)) {
  402. for (const value of output.value) {
  403. /* eslint-disable no-await-in-loop */
  404. await validateValue(value);
  405. /* eslint-enable no-await-in-loop */
  406. }
  407. }
  408. }
  409. }
  410. if (graph.metadata && (!Array.isArray(graph.metadata) || !graph.metadata.every((argument) => argument.name && argument.value !== undefined))) {
  411. throw new Error("Invalid graph metadata.");
  412. }
  413. for (const node of graph.nodes) {
  414. const type = node.type;
  415. if (!type || typeof type.name !== 'string') {
  416. throw new Error(`Invalid node type '${JSON.stringify(node.type)}'.`);
  417. }
  418. if (Array.isArray(type.nodes)) {
  419. /* eslint-disable no-await-in-loop */
  420. await validateGraph(type);
  421. /* eslint-enable no-await-in-loop */
  422. }
  423. view.Documentation.open(type);
  424. node.name.toString();
  425. node.description;
  426. if (node.metadata && (!Array.isArray(node.metadata) || !node.metadata.every((argument) => argument.name && argument.value !== undefined))) {
  427. throw new Error("Invalid node metadata.");
  428. }
  429. const attributes = node.attributes;
  430. if (attributes) {
  431. for (const attribute of attributes) {
  432. attribute.name.toString();
  433. attribute.name.length;
  434. const type = attribute.type;
  435. const value = attribute.value;
  436. if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
  437. /* eslint-disable no-await-in-loop */
  438. await validateGraph(value);
  439. /* eslint-enable no-await-in-loop */
  440. } else {
  441. let text = new view.Formatter(attribute.value, attribute.type).toString();
  442. if (text && text.length > 1000) {
  443. text = `${text.substring(0, 1000)}...`;
  444. }
  445. /* value = */ text.split('<');
  446. }
  447. }
  448. }
  449. const inputs = node.inputs;
  450. if (Array.isArray(inputs)) {
  451. for (const input of inputs) {
  452. input.name.toString();
  453. input.name.length;
  454. if (!input.type || input.type.endsWith('*')) {
  455. for (const value of input.value) {
  456. /* eslint-disable no-await-in-loop */
  457. await validateValue(value);
  458. /* eslint-enable no-await-in-loop */
  459. }
  460. if (this.tags.has('validation')) {
  461. if (input.value.length === 1 && input.value[0].initializer) {
  462. const sidebar = new view.TensorSidebar(this.view, input);
  463. sidebar.render();
  464. }
  465. }
  466. }
  467. }
  468. }
  469. const outputs = node.outputs;
  470. if (Array.isArray(outputs)) {
  471. for (const output of node.outputs) {
  472. output.name.toString();
  473. output.name.length;
  474. if (!output.type || output.type.endsWith('*')) {
  475. for (const value of output.value) {
  476. /* eslint-disable no-await-in-loop */
  477. await validateValue(value);
  478. /* eslint-enable no-await-in-loop */
  479. }
  480. }
  481. }
  482. }
  483. if (node.chain) {
  484. for (const chain of node.chain) {
  485. chain.name.toString();
  486. chain.name.length;
  487. }
  488. }
  489. const sidebar = new view.NodeSidebar(this.view, node);
  490. sidebar.render();
  491. }
  492. const sidebar = new view.ModelSidebar(this.view, this.model, graph);
  493. sidebar.render();
  494. /* eslint-enable no-unused-expressions */
  495. };
  496. const validateTarget = async (target) => {
  497. switch (target.type) {
  498. default: {
  499. await validateGraph(target);
  500. }
  501. }
  502. };
  503. for (const module of model.modules) {
  504. /* eslint-disable no-await-in-loop */
  505. await validateTarget(module);
  506. /* eslint-enable no-await-in-loop */
  507. }
  508. const functions = model.functions || [];
  509. for (const func of functions) {
  510. /* eslint-disable no-await-in-loop */
  511. await validateTarget(func);
  512. /* eslint-enable no-await-in-loop */
  513. }
  514. }
  515. async render() {
  516. for (const graph of this.model.modules) {
  517. const signatures = Array.isArray(graph.signatures) && graph.signatures.length > 0 ? graph.signatures : [graph];
  518. for (const signature of signatures) {
  519. /* eslint-disable no-await-in-loop */
  520. await this.view.render(graph, signature);
  521. /* eslint-enable no-await-in-loop */
  522. }
  523. }
  524. }
  525. }
  526. if (!worker_threads.isMainThread) {
  527. worker_threads.parentPort.addEventListener('message', async (e) => {
  528. const message = e.data;
  529. const response = {};
  530. try {
  531. const target = new Target(message);
  532. response.type = 'complete';
  533. response.target = target.name;
  534. target.on('status', (sender, message) => {
  535. message = { type: 'status', ...message };
  536. worker_threads.parentPort.postMessage(message);
  537. });
  538. if (message.measures) {
  539. target.measures = new Map();
  540. }
  541. await target.execute();
  542. response.measures = target.measures;
  543. } catch (error) {
  544. response.type = 'error';
  545. response.error = {
  546. name: error.name,
  547. message: error.message,
  548. stack: error.stack
  549. };
  550. const cause = error.cause;
  551. if (cause) {
  552. response.error.cause = {
  553. name: cause.name,
  554. message: cause.message
  555. };
  556. }
  557. }
  558. worker_threads.parentPort.postMessage(response);
  559. });
  560. }