2
0

worker.js 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. import * as base from '../source/base.js';
  2. import * as fs from 'fs/promises';
  3. import * as mock from './mock.js';
  4. import * as node from '../source/node.js';
  5. import * as path from 'path';
  6. import * as process from 'process';
  7. import * as python from '../source/python.js';
  8. import * as tar from '../source/tar.js';
  9. import * as url from 'url';
  10. import * as view from '../source/view.js';
  11. import * as worker_threads from 'worker_threads';
  12. import * as zip from '../source/zip.js';
  13. const access = async (path) => {
  14. try {
  15. await fs.access(path);
  16. return true;
  17. } catch {
  18. return false;
  19. }
  20. };
  21. const dirname = (...args) => {
  22. const file = url.fileURLToPath(import.meta.url);
  23. const dir = path.dirname(file);
  24. return path.join(dir, ...args);
  25. };
  26. const decompress = (buffer) => {
  27. let archive = zip.Archive.open(buffer, 'gzip');
  28. if (archive && archive.entries.size === 1) {
  29. const stream = archive.entries.values().next().value;
  30. buffer = stream.peek();
  31. }
  32. const formats = [zip, tar];
  33. for (const module of formats) {
  34. archive = module.Archive.open(buffer);
  35. if (archive) {
  36. break;
  37. }
  38. }
  39. return archive;
  40. };
  41. export class Target {
  42. constructor(item) {
  43. Object.assign(this, item);
  44. this.events = {};
  45. this.tags = new Set(this.tags);
  46. this.folder = item.type ? path.normalize(dirname('..', 'third_party' , 'test', item.type)) : process.cwd();
  47. this.assert = !this.assert || Array.isArray(this.assert) ? this.assert : [this.assert];
  48. this.serial = false;
  49. }
  50. on(event, callback) {
  51. this.events[event] = this.events[event] || [];
  52. this.events[event].push(callback);
  53. }
  54. emit(event, data) {
  55. if (this.events && this.events[event]) {
  56. for (const callback of this.events[event]) {
  57. callback(this, data);
  58. }
  59. }
  60. }
  61. status(message) {
  62. this.emit('status', message);
  63. }
  64. async execute() {
  65. if (this.measures) {
  66. this.measures.set('name', this.name);
  67. }
  68. await zip.Archive.import();
  69. const environment = { zoom: 'none', serial: this.serial };
  70. this.host = await new mock.Host(environment);
  71. this.view = new view.View(this.host);
  72. this.view.options.attributes = true;
  73. this.view.options.initializers = true;
  74. const time = async (method) => {
  75. const start = process.hrtime.bigint();
  76. let err = null;
  77. try {
  78. await method.call(this);
  79. } catch (error) {
  80. err = error;
  81. }
  82. const duration = Number(process.hrtime.bigint() - start) / 1e9;
  83. if (this.measures) {
  84. this.measures.set(method.name, duration);
  85. }
  86. if (err) {
  87. throw err;
  88. }
  89. };
  90. this.status({ name: 'name', target: this.name });
  91. const errors = [];
  92. try {
  93. await time(this.download);
  94. await time(this.load);
  95. await time(this.validate);
  96. if (!this.tags.has('skip-render')) {
  97. await time(this.render);
  98. }
  99. } catch (error) {
  100. errors.push(error);
  101. }
  102. errors.push(...this.host.errors);
  103. if (errors.length === 0 && this.error) {
  104. throw new Error('Expected error.');
  105. }
  106. if (errors.length > 0 && (!this.error || errors.map((error) => error.message).join('\n') !== this.error)) {
  107. throw errors[0];
  108. }
  109. this.view.dispose();
  110. }
  111. async request(url, init) {
  112. const response = await fetch(url, init);
  113. if (!response.ok) {
  114. throw new Error(response.status.toString());
  115. }
  116. if (response.body) {
  117. const reader = response.body.getReader();
  118. const length = response.headers.has('Content-Length') ? parseInt(response.headers.get('Content-Length'), 10) : -1;
  119. let position = 0;
  120. /* eslint-disable consistent-this */
  121. const target = this;
  122. /* eslint-enable consistent-this */
  123. const stream = new ReadableStream({
  124. async start(controller) {
  125. const read = async () => {
  126. try {
  127. const result = await reader.read();
  128. if (result.done) {
  129. target.status({ name: 'download' });
  130. controller.close();
  131. } else {
  132. position += result.value.length;
  133. if (length >= 0) {
  134. const percent = position / length;
  135. target.status({ name: 'download', target: url, percent });
  136. } else {
  137. target.status({ name: 'download', target: url, position });
  138. }
  139. controller.enqueue(result.value);
  140. return await read();
  141. }
  142. } catch (error) {
  143. controller.error(error);
  144. throw error;
  145. }
  146. return null;
  147. };
  148. return read();
  149. }
  150. });
  151. return new Response(stream, {
  152. status: response.status,
  153. statusText: response.statusText,
  154. headers: response.headers
  155. });
  156. }
  157. return response;
  158. }
  159. async download(targets, sources) {
  160. targets = targets || Array.from(this.targets);
  161. sources = sources || this.source;
  162. const files = targets.map((file) => path.resolve(this.folder, file));
  163. const exists = await Promise.all(files.map((file) => access(file)));
  164. if (exists.every((value) => value)) {
  165. return;
  166. }
  167. if (!sources) {
  168. throw new Error('Download source not specified.');
  169. }
  170. let source = '';
  171. let sourceFiles = [];
  172. const match = sources.match(/^(.*?)\[(.*?)\](.*)$/);
  173. if (match) {
  174. [, source, sourceFiles, sources] = match;
  175. sourceFiles = sourceFiles.split(',').map((file) => file.trim());
  176. sources = sources && sources.startsWith(',') ? sources.substring(1).trim() : '';
  177. } else {
  178. const comma = sources.indexOf(',');
  179. if (comma === -1) {
  180. source = sources;
  181. sources = '';
  182. } else {
  183. source = sources.substring(0, comma);
  184. sources = sources.substring(comma + 1);
  185. }
  186. }
  187. await Promise.all(targets.map((target) => {
  188. const dir = path.dirname(`${this.folder}/${target}`);
  189. return fs.mkdir(dir, { recursive: true });
  190. }));
  191. const response = await this.request(source);
  192. const buffer = await response.arrayBuffer();
  193. const data = new Uint8Array(buffer);
  194. if (sourceFiles.length > 0) {
  195. this.status({ name: 'decompress' });
  196. const archive = decompress(data);
  197. for (const name of sourceFiles) {
  198. this.status({ name: 'write', target: name });
  199. if (name === '.') {
  200. const target = targets.shift();
  201. const dir = path.join(this.folder, target);
  202. /* eslint-disable no-await-in-loop */
  203. await fs.mkdir(dir, { recursive: true });
  204. /* eslint-enable no-await-in-loop */
  205. } else {
  206. const stream = archive.entries.get(name);
  207. if (!stream) {
  208. throw new Error(`Entry not found '${name}. Archive contains entries: ${JSON.stringify(Array.from(archive.entries.keys()))} .`);
  209. }
  210. const target = targets.shift();
  211. const buffer = stream.peek();
  212. const file = path.join(this.folder, target);
  213. /* eslint-disable no-await-in-loop */
  214. await fs.writeFile(file, buffer, null);
  215. /* eslint-enable no-await-in-loop */
  216. }
  217. }
  218. } else {
  219. const target = targets.shift();
  220. this.status({ name: 'write', target });
  221. await fs.writeFile(`${this.folder}/${target}`, data, null);
  222. }
  223. if (targets.length > 0 && sources.length > 0) {
  224. await this.download(targets, sources);
  225. }
  226. }
  227. async load() {
  228. const target = path.resolve(this.folder, this.targets[0]);
  229. const identifier = path.basename(target);
  230. const stat = await fs.stat(target);
  231. let context = null;
  232. if (stat.isFile()) {
  233. const stream = new node.FileStream(target, 0, stat.size, stat.mtimeMs);
  234. const dirname = path.dirname(target);
  235. context = new mock.Context(this.host, dirname, identifier, stream, new Map());
  236. } else if (stat.isDirectory()) {
  237. const entries = new Map();
  238. const file = async (pathname) => {
  239. const stat = await fs.stat(pathname);
  240. const stream = new node.FileStream(pathname, 0, stat.size, stat.mtimeMs);
  241. const name = pathname.split(path.sep).join(path.posix.sep);
  242. entries.set(name, stream);
  243. };
  244. const walk = async (dir) => {
  245. const stats = await fs.readdir(dir, { withFileTypes: true });
  246. const promises = [];
  247. for (const stat of stats) {
  248. const pathname = path.join(dir, stat.name);
  249. if (stat.isDirectory()) {
  250. promises.push(walk(pathname));
  251. } else if (stat.isFile()) {
  252. promises.push(file(pathname));
  253. }
  254. }
  255. await Promise.all(promises);
  256. };
  257. await walk(target);
  258. context = new mock.Context(this.host, target, identifier, null, entries);
  259. }
  260. const modelFactoryService = new view.ModelFactoryService(this.host);
  261. this.model = await modelFactoryService.open(context);
  262. this.view.model = this.model;
  263. }
  264. async validate() {
  265. const model = this.model;
  266. if (!model.format || (this.format && this.format !== model.format)) {
  267. throw new Error(`Invalid model format '${model.format}'.`);
  268. }
  269. if (this.producer && model.producer !== this.producer) {
  270. throw new Error(`Invalid producer '${model.producer}'.`);
  271. }
  272. if (this.runtime && model.runtime !== this.runtime) {
  273. throw new Error(`Invalid runtime '${model.runtime}'.`);
  274. }
  275. if (model.metadata && (!Array.isArray(model.metadata) || !model.metadata.every((argument) => argument.name && (argument.value || argument.value === null || argument.value === '' || argument.value === false || argument.value === 0)))) {
  276. throw new Error("Invalid model metadata.");
  277. }
  278. if (this.assert) {
  279. for (const assert of this.assert) {
  280. const parts = assert.split('==').map((item) => item.trim());
  281. const properties = parts[0].split('.');
  282. const value = JSON.parse(parts[1].replace(/\s*'|'\s*/g, '"'));
  283. let context = { model };
  284. while (properties.length) {
  285. const property = properties.shift();
  286. if (context[property] !== undefined) {
  287. context = context[property];
  288. continue;
  289. }
  290. const match = /(.*)\[(.*)\]/.exec(property);
  291. if (match && match.length === 3 && context[match[1]] !== undefined) {
  292. const array = context[match[1]];
  293. const index = parseInt(match[2], 10);
  294. if (array[index] !== undefined) {
  295. context = array[index];
  296. continue;
  297. }
  298. }
  299. throw new Error(`Invalid property path '${parts[0]}'.`);
  300. }
  301. if (context !== value) {
  302. throw new Error(`Invalid '${context}' != '${assert}'.`);
  303. }
  304. }
  305. }
  306. if (model.version || model.description || model.author || model.license) {
  307. // continue
  308. }
  309. const validateGraph = async (graph) => {
  310. /* eslint-disable no-unused-expressions */
  311. const values = new Map();
  312. const validateValue = async (value) => {
  313. if (value === null) {
  314. return;
  315. }
  316. value.name.toString();
  317. value.name.length;
  318. value.description;
  319. if (value.quantization) {
  320. if (!this.tags.has('quantization')) {
  321. throw new Error("Invalid 'quantization' tag.");
  322. }
  323. const quantization = new view.Quantization(value.quantization);
  324. quantization.toString();
  325. }
  326. if (value.type) {
  327. value.type.toString();
  328. }
  329. if (value.initializer) {
  330. value.initializer.type.toString();
  331. if (value.initializer && value.initializer.peek && !value.initializer.peek()) {
  332. await value.initializer.read();
  333. }
  334. const tensor = new base.Tensor(value.initializer);
  335. if (!this.tags.has('skip-tensor-value')) {
  336. if (tensor.encoding !== '<' && tensor.encoding !== '>' && tensor.encoding !== '|') {
  337. throw new Error(`Tensor encoding '${tensor.encoding}' is not implemented.`);
  338. }
  339. if (tensor.layout && (tensor.layout !== 'sparse' && tensor.layout !== 'sparse.coo')) {
  340. throw new Error(`Tensor layout '${tensor.layout}' is not implemented.`);
  341. }
  342. if (!tensor.empty) {
  343. if (tensor.type && tensor.type.dataType === '?') {
  344. throw new Error('Tensor data type is not defined.');
  345. } else if (tensor.type && !tensor.type.shape) {
  346. throw new Error('Tensor shape is not defined.');
  347. } else {
  348. tensor.toString();
  349. if (this.tags.has('validation')) {
  350. const size = tensor.type.shape.dimensions.reduce((a, b) => a * b, 1);
  351. if (size < 8192 && tensor.type &&
  352. tensor.type.dataType !== '?' &&
  353. tensor.type.dataType !== 'string' &&
  354. tensor.type.dataType !== 'int128' &&
  355. tensor.type.dataType !== 'complex<int32>') {
  356. let data_type = '?';
  357. switch (tensor.type.dataType) {
  358. case 'boolean': data_type = 'bool'; break;
  359. case 'bfloat16': data_type = 'float32'; break;
  360. case 'float4e2m1fn': data_type = 'float16'; break;
  361. case 'float6e2m3fn': data_type = 'float16'; break;
  362. case 'float6e3m2fn': data_type = 'float16'; break;
  363. case 'float8e5m2': data_type = 'float16'; break;
  364. case 'float8e5m2fnuz': data_type = 'float16'; break;
  365. case 'float8e3m4': data_type = 'float16'; break;
  366. case 'float8e4m3': data_type = 'float16'; break;
  367. case 'float8e4m3fn': data_type = 'float16'; break;
  368. case 'float8e4m3fnuz': data_type = 'float16'; break;
  369. case 'float8e4m3b11fnuz': data_type = 'float16'; break;
  370. case 'float8e8m0fnu': data_type = 'float16'; break;
  371. case 'complex<float32>': data_type = 'complex64'; break;
  372. case 'complex<float64>': data_type = 'complex128'; break;
  373. case 'int1': data_type = 'int8'; break;
  374. case 'int2': data_type = 'int8'; break;
  375. case 'int4': data_type = 'int8'; break;
  376. case 'int48': data_type = 'int64'; break;
  377. case 'uint2': data_type = 'uint8'; break;
  378. case 'uint4': data_type = 'uint8'; break;
  379. default: data_type = tensor.type.dataType; break;
  380. }
  381. Target.execution = Target.execution || new python.Execution();
  382. const execution = Target.execution;
  383. const io = execution.__import__('io');
  384. const numpy = execution.__import__('numpy');
  385. const bytes = new io.BytesIO();
  386. const dtype = new numpy.dtype(data_type);
  387. const array = numpy.asarray(tensor.value, dtype);
  388. numpy.save(bytes, array);
  389. }
  390. }
  391. }
  392. }
  393. }
  394. } else if (value.name.length === 0) {
  395. throw new Error('Empty value name.');
  396. }
  397. if (value.name.length > 0 && value.initializer === null) {
  398. if (!values.has(value.name)) {
  399. values.set(value.name, value);
  400. } else if (value !== values.get(value.name)) {
  401. throw new Error(`Duplicate value '${value.name}'.`);
  402. }
  403. }
  404. };
  405. const signatures = Array.isArray(graph.signatures) ? graph.signatures : [graph];
  406. for (const signature of signatures) {
  407. for (const input of signature.inputs) {
  408. input.name.toString();
  409. input.name.length;
  410. for (const value of input.value) {
  411. /* eslint-disable no-await-in-loop */
  412. await validateValue(value);
  413. /* eslint-enable no-await-in-loop */
  414. }
  415. }
  416. for (const output of signature.outputs) {
  417. output.name.toString();
  418. output.name.length;
  419. if (Array.isArray(output.value)) {
  420. for (const value of output.value) {
  421. /* eslint-disable no-await-in-loop */
  422. await validateValue(value);
  423. /* eslint-enable no-await-in-loop */
  424. }
  425. }
  426. }
  427. }
  428. if (graph.metadata && (!Array.isArray(graph.metadata) || !graph.metadata.every((argument) => argument.name && argument.value !== undefined))) {
  429. throw new Error("Invalid graph metadata.");
  430. }
  431. for (const node of graph.nodes) {
  432. const type = node.type;
  433. if (!type || typeof type.name !== 'string') {
  434. throw new Error(`Invalid node type '${JSON.stringify(node.type)}'.`);
  435. }
  436. if (Array.isArray(type.nodes)) {
  437. /* eslint-disable no-await-in-loop */
  438. await validateGraph(type);
  439. /* eslint-enable no-await-in-loop */
  440. }
  441. view.Documentation.open(type);
  442. node.name.toString();
  443. node.description;
  444. if (node.metadata && (!Array.isArray(node.metadata) || !node.metadata.every((argument) => argument.name && argument.value !== undefined))) {
  445. throw new Error("Invalid node metadata.");
  446. }
  447. const attributes = node.attributes;
  448. if (attributes) {
  449. for (const attribute of attributes) {
  450. attribute.name.toString();
  451. attribute.name.length;
  452. const type = attribute.type;
  453. const value = attribute.value;
  454. if ((type === 'graph' || type === 'function') && value && Array.isArray(value.nodes)) {
  455. /* eslint-disable no-await-in-loop */
  456. await validateGraph(value);
  457. /* eslint-enable no-await-in-loop */
  458. } else {
  459. let text = new view.Formatter(attribute.value, attribute.type).toString();
  460. if (text && text.length > 1000) {
  461. text = `${text.substring(0, 1000)}...`;
  462. }
  463. /* value = */ text.split('<');
  464. }
  465. }
  466. }
  467. const inputs = node.inputs;
  468. if (Array.isArray(inputs)) {
  469. for (const input of inputs) {
  470. input.name.toString();
  471. input.name.length;
  472. if (!input.type || input.type.endsWith('*')) {
  473. for (const value of input.value) {
  474. /* eslint-disable no-await-in-loop */
  475. await validateValue(value);
  476. /* eslint-enable no-await-in-loop */
  477. }
  478. if (this.tags.has('validation')) {
  479. if (input.value.length === 1 && input.value[0].initializer) {
  480. const sidebar = new view.TensorSidebar(this.view, input);
  481. sidebar.render();
  482. }
  483. }
  484. }
  485. }
  486. }
  487. const outputs = node.outputs;
  488. if (Array.isArray(outputs)) {
  489. for (const output of node.outputs) {
  490. output.name.toString();
  491. output.name.length;
  492. if (!output.type || output.type.endsWith('*')) {
  493. for (const value of output.value) {
  494. /* eslint-disable no-await-in-loop */
  495. await validateValue(value);
  496. /* eslint-enable no-await-in-loop */
  497. }
  498. }
  499. }
  500. }
  501. if (node.chain) {
  502. for (const chain of node.chain) {
  503. chain.name.toString();
  504. chain.name.length;
  505. }
  506. }
  507. const sidebar = new view.NodeSidebar(this.view, node);
  508. sidebar.render();
  509. }
  510. const sidebar = new view.ModelSidebar(this.view, this.model, graph);
  511. sidebar.render();
  512. /* eslint-enable no-unused-expressions */
  513. };
  514. const validateTarget = async (target) => {
  515. switch (target.type) {
  516. default: {
  517. await validateGraph(target);
  518. }
  519. }
  520. };
  521. for (const module of model.modules) {
  522. /* eslint-disable no-await-in-loop */
  523. await validateTarget(module);
  524. /* eslint-enable no-await-in-loop */
  525. }
  526. const functions = model.functions || [];
  527. for (const func of functions) {
  528. /* eslint-disable no-await-in-loop */
  529. await validateTarget(func);
  530. /* eslint-enable no-await-in-loop */
  531. }
  532. }
  533. async render() {
  534. for (const graph of this.model.modules) {
  535. const signatures = Array.isArray(graph.signatures) && graph.signatures.length > 0 ? graph.signatures : [graph];
  536. for (const signature of signatures) {
  537. /* eslint-disable no-await-in-loop */
  538. await this.view.render(graph, signature);
  539. /* eslint-enable no-await-in-loop */
  540. }
  541. }
  542. }
  543. }
  544. if (!worker_threads.isMainThread) {
  545. worker_threads.parentPort.addEventListener('message', async (e) => {
  546. const message = e.data;
  547. const response = {};
  548. try {
  549. const target = new Target(message);
  550. response.type = 'complete';
  551. response.target = target.name;
  552. target.on('status', (sender, message) => {
  553. message = { type: 'status', ...message };
  554. worker_threads.parentPort.postMessage(message);
  555. });
  556. if (message.measures) {
  557. target.measures = new Map();
  558. }
  559. await target.execute();
  560. response.measures = target.measures;
  561. } catch (error) {
  562. response.type = 'error';
  563. response.error = {
  564. name: error.name,
  565. message: error.message,
  566. stack: error.stack
  567. };
  568. const cause = error.cause;
  569. if (cause) {
  570. response.error.cause = {
  571. name: cause.name,
  572. message: cause.message
  573. };
  574. }
  575. }
  576. worker_threads.parentPort.postMessage(response);
  577. });
  578. }