megengine.js 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. // Experimental
  2. import * as flatbuffers from './flatbuffers.js';
  3. const megengine = {};
  4. megengine.ModelFactory = class {
  5. async match(context) {
  6. const stream = context.stream;
  7. if (stream && stream.length >= 12) {
  8. let buffer = stream.peek(12);
  9. const tag = String.fromCharCode.apply(null, buffer);
  10. const position = tag.startsWith('mgbtest0') ? 12 : 0;
  11. if (stream.length > (position + 12)) {
  12. buffer = stream.peek(24).slice(position, position + 12);
  13. const size = (buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer[3] << 24) >>> 0;
  14. if (position > 0 || size === (stream.length - position - 4)) {
  15. const reader = flatbuffers.BinaryReader.open(stream, position + 4);
  16. if (reader.identifier === 'mgv2') {
  17. return context.set('megengine.mge', reader);
  18. }
  19. }
  20. }
  21. for (const value of ['mgb0001', 'mgb0000a', 'MGBS', 'MGBC']) {
  22. if (tag.startsWith(value)) {
  23. return context.set(`megengine.${value}`);
  24. }
  25. }
  26. }
  27. const obj = await context.peek('pkl');
  28. if (obj && obj.__class__ && obj.__class__.__module__ === 'megengine.traced_module.traced_module' && obj.__class__.__name__ === 'TracedModule') {
  29. return context.set('megengine.tm');
  30. }
  31. return null;
  32. }
  33. async open(context) {
  34. const metadata = await context.metadata('megengine-metadata.json');
  35. switch (context.type) {
  36. case 'megengine.tm': {
  37. const obj = await context.peek('pkl');
  38. return new megengine.Model(metadata, obj, context.type);
  39. }
  40. case 'megengine.mge': {
  41. megengine.schema = await context.require('./megengine-schema');
  42. megengine.schema = megengine.schema.mgb.serialization.fbs;
  43. let model = null;
  44. try {
  45. const reader = context.value;
  46. model = megengine.schema.v2.Model.create(reader);
  47. } catch (error) {
  48. const message = error && error.message ? error.message : error.toString();
  49. throw new megengine.Error(`File format is not megengine.Model (${message.replace(/\.$/, '')}).`);
  50. }
  51. return new megengine.Model(metadata, model, context.type);
  52. }
  53. default: {
  54. throw new megengine.Error(`Unsupported MegEngine format '${context.type.replace(/^megengine\./, '')}'.`);
  55. }
  56. }
  57. }
  58. };
  59. megengine.Model = class {
  60. constructor(metadata, obj, type) {
  61. this.format = 'MegEngine';
  62. if (type === 'megengine.tm') {
  63. this.format += obj.dump_info && obj.dump_info.has('version') ? ` v${obj.dump_info.get('version')}` : '';
  64. } else if (type === 'megengine.mge') {
  65. this.format += ` Mge${obj.model_version ? ` v${obj.model_version}` : ''}`;
  66. }
  67. this.modules = [new megengine.Graph(metadata, obj)];
  68. }
  69. };
  70. megengine.Graph = class {
  71. constructor(metadata, obj) {
  72. this.name = '';
  73. this.nodes = [];
  74. this.inputs = [];
  75. this.outputs = [];
  76. const values = new Map();
  77. values.map = (name, type, tensor, quantization) => {
  78. type = type || null;
  79. tensor = tensor || null;
  80. if (tensor) {
  81. return new megengine.Value(name, type, tensor, quantization);
  82. }
  83. if (!values.has(name)) {
  84. values.set(name, new megengine.Value(name, type, tensor, quantization));
  85. } else if ((type && !type.equals(values.get(name).type)) || tensor) {
  86. throw new megengine.Error(`Duplicate value '${name}'.`);
  87. }
  88. return values.get(name);
  89. };
  90. const loadGraph = (tmodule, igraph, context, namePrefix, metadata, isRoot) => {
  91. const expressions = igraph._exprs;
  92. const getTensorType = (dtype, shape) => {
  93. dtype = dtype ? dtype.__name__ : null;
  94. return new megengine.TensorType(dtype, new megengine.TensorShape(shape));
  95. };
  96. if (isRoot) {
  97. for (const node of igraph._inputs) {
  98. if (node.__class__.__name__ !== 'ModuleNode') {
  99. const type = getTensorType(node._dtype, node._shape);
  100. const value = values.map(node._name, type, null);
  101. const argument = new megengine.Argument(node._name, [value]);
  102. this.inputs.push(argument);
  103. }
  104. }
  105. for (const node of igraph._outputs) {
  106. const type = getTensorType(node._dtype, node._shape);
  107. const value = values.map(node._name, type, null);
  108. const argument = new megengine.Argument(node._name, [value]);
  109. this.outputs.push(argument);
  110. }
  111. }
  112. const parseGetAttr = (module, expression) => {
  113. let names = expression.name.split('.');
  114. while (expression.inputs[0].expr.__class__.__name__ === 'GetAttr') {
  115. expression = expression.inputs[0].expr;
  116. names = expression.name.split('.').concat(names);
  117. }
  118. let obj = module;
  119. for (const name of names) {
  120. if (obj.__class__.__name__ === '_ModuleState') {
  121. obj = obj.state;
  122. }
  123. obj = obj[name];
  124. }
  125. return obj;
  126. };
  127. const parseArgs = (args, kwargs, meta) => {
  128. const state = {};
  129. let argIdx = 0;
  130. const processArgs = (inp, startIdx) => {
  131. while (typeof inp === 'string' && inp.indexOf('Tensor') !== -1) {
  132. inp = inp.replace('Tensor', `input${startIdx === 0 ? '' : startIdx}`);
  133. startIdx += 1;
  134. }
  135. return [inp, startIdx];
  136. };
  137. const formatTreeDef = (obj) => {
  138. if (obj.__class__.__name__ !== 'TreeDef' && obj.__class__.__name__ !== 'LeafDef') {
  139. throw new megengine.Error(`Invalid argument '${obj.__class__.__name__}'.`);
  140. }
  141. if (obj.__class__.__name__ === 'TreeDef') {
  142. const type = typeof obj.type === 'string' ? obj.type.split('.').slice(-1)[0] : obj.type.__name__;
  143. const list = obj.children_defs.map((child) => formatTreeDef(child));
  144. switch (type) {
  145. case 'tuple': {
  146. return `(${list.join(',')})`;
  147. }
  148. case 'slice': {
  149. return list.join(':');
  150. }
  151. case 'list': {
  152. return `[${list.join(',')}]`;
  153. }
  154. case 'dict': {
  155. let content = '';
  156. for (let i = 0; i < this.children_defs.length; i++) {
  157. content += `${this.aux_data[i]}:${list[i]}`;
  158. }
  159. return `{${content}}`;
  160. }
  161. default: {
  162. return `${type}(${list.join(',')})`;
  163. }
  164. }
  165. }
  166. if (obj.const_val !== null) {
  167. return obj.const_val;
  168. } else if (obj.type[0].__module__ !== undefined) {
  169. return obj.type[0].__name__;
  170. }
  171. return 'None';
  172. };
  173. let inpIdx = 0;
  174. for (const arg of args.children_defs) {
  175. let name = '';
  176. if (meta.attributes === undefined || (meta.attributes.length !== args.children_defs.length && meta.varargs === null)) {
  177. name = `arg${argIdx}`;
  178. } else if (argIdx < meta.attributes.length) {
  179. name = meta.attributes[argIdx].name;
  180. } else {
  181. name = meta.varargs + (argIdx - meta.attributes.length);
  182. }
  183. const [value, index] = processArgs(formatTreeDef(arg), inpIdx);
  184. state[name] = value;
  185. inpIdx = index;
  186. argIdx += 1;
  187. }
  188. for (let i = 0; i < kwargs.children_defs.length; i++) {
  189. const [value, index] = processArgs(formatTreeDef(kwargs.children_defs[i]), inpIdx);
  190. state[kwargs.aux_data[i]] = value;
  191. inpIdx = index;
  192. }
  193. return state;
  194. };
  195. const getName = (context, name) => {
  196. let rst = name;
  197. while (context.get(rst) !== undefined) {
  198. if (rst === context.get(rst)) {
  199. return rst;
  200. }
  201. rst = context.get(rst);
  202. }
  203. return rst;
  204. };
  205. const getFullName = (prefix, name) => {
  206. return prefix === '' ? name : `${prefix}_${name}`;
  207. };
  208. for (const expression of expressions) {
  209. const type = expression.__class__.__name__;
  210. for (const input of expression.inputs) {
  211. input._fullname = getName(context, getFullName(namePrefix, input._name));
  212. }
  213. for (const output of expression.outputs) {
  214. output._fullname = getName(context, getFullName(namePrefix, output._name));
  215. }
  216. switch (type) {
  217. case 'Input': {
  218. break;
  219. }
  220. case 'GetAttr': {
  221. if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
  222. const tensor = parseGetAttr(tmodule, expression);
  223. const type = getTensorType(tensor.dtype, tensor.data.shape);
  224. const data = tensor.data.data;
  225. expression.outputs[0].initializer = new megengine.Tensor(expression.name, type, data);
  226. }
  227. break;
  228. }
  229. case 'Constant': {
  230. if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
  231. const tensor = expression.value;
  232. const type = getTensorType(tensor.dtype, tensor.data.shape);
  233. const data = tensor.data.data;
  234. expression.outputs[0].initializer = new megengine.Tensor('', type, data);
  235. }
  236. break;
  237. }
  238. case 'CallMethod': {
  239. const obj = { 'name': '' };
  240. let state = {};
  241. if (expression.method === '__call__') {
  242. const module = parseGetAttr(tmodule, expression.inputs[0].expr);
  243. const getModuleType = (obj) => {
  244. if (obj.module !== undefined) {
  245. return `${obj.module[0]}.${obj.module[1]}`;
  246. }
  247. return `${obj.__class__.__module__}.${obj.__class__.__name__}`;
  248. };
  249. const moduleType = module.__class__.__name__ === 'TracedModule' ? 'TracedModule' : getModuleType(module);
  250. if (moduleType === 'TracedModule') {
  251. const moduleName = expression.outputs[0]._name.endsWith("_out") ? expression.outputs[0]._name.substring(0, expression.outputs[0]._name.length - 4) : expression.outputs[0]._name;
  252. const prefix = getFullName(namePrefix, moduleName);
  253. const internalGraph = module.argdef_graph_map.get(expression.arg_def);
  254. for (let i = 0; i < expression.inputs.length; i++) {
  255. const actualName = getFullName(namePrefix, expression.inputs[i]._name);
  256. const internalName = getFullName(prefix, internalGraph._inputs[i]._name);
  257. context.set(internalName, actualName);
  258. }
  259. for (let i = 0; i < expression.outputs.length; i++) {
  260. const actualName = getFullName(namePrefix, expression.outputs[i]._name);
  261. const internalName = getFullName(prefix, internalGraph._outputs[i]._name);
  262. if (context.get(internalName) === undefined) {
  263. context.set(internalName, actualName);
  264. } else {
  265. context.set(actualName, context.get(internalName));
  266. }
  267. }
  268. loadGraph(module, internalGraph, context, prefix, metadata, false);
  269. continue;
  270. }
  271. obj.type = moduleType;
  272. state = module.__class__.__name__ === 'TracedModule' ? module : module.state;
  273. if (state === undefined) {
  274. state = module;
  275. }
  276. } else {
  277. obj.type = expression.method;
  278. const [args, kwargs] = expression.arg_def.children_defs;
  279. const schema = metadata.type(expression.method);
  280. state = parseArgs(args, kwargs, schema);
  281. }
  282. const node = new megengine.Node(metadata, obj, values, null, expression, state);
  283. this.nodes.push(node);
  284. break;
  285. }
  286. case 'CallFunction': {
  287. const getFunctionType = (obj) => {
  288. if (obj.func.__module__ !== undefined) {
  289. return `${obj.func.__module__}.${obj.func.__name__}`;
  290. }
  291. return `${obj.func[0]}.${obj.func[1]}`;
  292. };
  293. const func = getFunctionType(expression);
  294. const item = { 'name': '', 'type': func };
  295. const [args, kwargs] = expression.arg_def.children_defs;
  296. const schema = metadata.type(func);
  297. const state = parseArgs(args, kwargs, schema);
  298. const node = new megengine.Node(metadata, item, values, null, expression, state);
  299. this.nodes.push(node);
  300. break;
  301. }
  302. case 'Apply': {
  303. const opdef = expression.opdef_state ? expression.opdef_state.get('opdef_type') : expression.opdef.type;
  304. const item = { 'name': '', 'type': `${opdef.__module__}.${opdef.__name__}` };
  305. const node = new megengine.Node(metadata, item, values, null, expression, expression.opdef_state);
  306. this.nodes.push(node);
  307. break;
  308. }
  309. default: {
  310. break;
  311. }
  312. }
  313. }
  314. };
  315. if (obj.argdef_graph_map) {
  316. const [graph] = Array.from(obj.argdef_graph_map.values());
  317. loadGraph(obj, graph, new Map(), '', metadata, true);
  318. return;
  319. }
  320. const extraInfoNameset = new Set();
  321. const getExtraInfo = (opr) => {
  322. let name = opr.name;
  323. let repeatIdx = 0;
  324. while (extraInfoNameset.has(name)) {
  325. for (const id of opr.inputs) {
  326. name = `${name}[${id}]`;
  327. }
  328. name += repeatIdx;
  329. repeatIdx += 1;
  330. }
  331. extraInfoNameset.add(name);
  332. const type = opr.type.replace(/V(\d+)$/, '');
  333. const args = [];
  334. if (opr.tensors.length > 0) {
  335. const [tensor] = opr.tensors;
  336. const type = new megengine.TensorType(tensor.dtype.type, new megengine.TensorShape(tensor.shape));
  337. const data = tensor.data.byteLength === 0 ? undefined : tensor.data.slice(0);
  338. const initializer = opr.type === 'Host2DeviceCopy' ? undefined : new megengine.Tensor('', type, data);
  339. const quantization = tensor.dtype.param ? { scale: tensor.dtype.param.scale, zeroPoint: tensor.dtype.param.zero_point } : null;
  340. const value = values.map(name, type, initializer, quantization);
  341. args.push(value);
  342. } else {
  343. const type = opr.shape ? new megengine.TensorType('?', new megengine.TensorShape(opr.shape)) : null;
  344. const value = values.map(name, type);
  345. args.push(value);
  346. }
  347. return { name, type, args };
  348. };
  349. const getAllOprAndTensor = (oprs) => {
  350. const allOprAndTensor = new Map();
  351. for (const opr of oprs) {
  352. if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.outputs.length > 1) {
  353. if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.type === 'MultipleDeviceTensorHolder') {
  354. opr.type = 'ImmutableTensor';
  355. }
  356. for (let id = 0; id < opr.outputs.length; id++) {
  357. const keyId = opr.outputs[id];
  358. const name = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].name : String(keyId);
  359. const type = opr.type;
  360. const tensors = opr.tensors.length ? [opr.tensors[id]] : [];
  361. const onlyShape = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].shape : [];
  362. allOprAndTensor.set(keyId, { name, type, tensors, shape: onlyShape, inputs: opr.inputs, outputs: opr.outputs });
  363. const _opr = allOprAndTensor.get(keyId);
  364. _opr.extraInfo = getExtraInfo(_opr);
  365. }
  366. } else {
  367. const [keyId] = opr.outputs;
  368. opr.name = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].name : String(keyId);
  369. if (obj.middle_tensors[keyId] && obj.middle_tensors[keyId].shape) {
  370. opr.shape = obj.middle_tensors[keyId].shape;
  371. }
  372. allOprAndTensor.set(keyId, opr);
  373. const _opr = allOprAndTensor.get(keyId);
  374. _opr.extraInfo = getExtraInfo(_opr);
  375. }
  376. }
  377. return allOprAndTensor;
  378. };
  379. const allOprAndTensor = getAllOprAndTensor(obj.oprs);
  380. for (const op of Array.from(allOprAndTensor.values())) {
  381. if (op.type === 'Host2DeviceCopy') {
  382. const argument = new megengine.Argument('input', op.extraInfo.args);
  383. this.inputs.push(argument);
  384. } else if (op.type !== 'ImmutableTensor') {
  385. this.nodes.push(new megengine.Node(metadata, op, values, allOprAndTensor));
  386. }
  387. }
  388. for (let i = 0; i < obj.output_vars_idx.length; i++) {
  389. const id = obj.output_vars_idx[i].compact_id;
  390. const out_type = `output${i === 0 ? '' : i}`;
  391. const argument = new megengine.Argument(out_type, allOprAndTensor.get(id).extraInfo.args);
  392. this.outputs.push(argument);
  393. }
  394. }
  395. };
  396. megengine.Argument = class {
  397. constructor(name, value, type = null, visible = true) {
  398. this.name = name;
  399. this.value = value;
  400. this.type = type;
  401. this.visible = visible;
  402. }
  403. };
  404. megengine.Value = class {
  405. constructor(name, type, initializer, quantization) {
  406. if (typeof name !== 'string') {
  407. throw new megengine.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  408. }
  409. this.name = name;
  410. this.type = !type && initializer ? initializer.type : type;
  411. this.initializer = initializer;
  412. if (quantization && ((quantization.scale !== undefined && quantization.scale !== 1) || (quantization.zeroPoint !== undefined && quantization.zeroPoint !== 0))) {
  413. this.quantization = {
  414. type: 'linear',
  415. scale: [quantization.scale],
  416. offset: [quantization.zeroPoint]
  417. };
  418. }
  419. }
  420. };
  421. megengine.Node = class {
  422. constructor(metadata, obj, values, allOprAndTensor, expr, state) {
  423. this.name = '';
  424. const type = metadata.type(obj.type);
  425. this.type = type ? { ...type } : { name: obj.type };
  426. this.type.name = this.type.name.replace(/V(\d+)$/, '');
  427. if (this.type.name.length > 4 && this.type.name.startsWith('__') && this.type.name.endsWith('__')) {
  428. this.type.name = this.type.name.substring(2, this.type.name.length - 2);
  429. }
  430. this.inputs = [];
  431. this.outputs = [];
  432. this.chain = [];
  433. this.attributes = [];
  434. const attributes = [];
  435. if (obj.inputs && obj.outputs) {
  436. const inputSchemas = this.type && this.type.inputs ? [...this.type.inputs] : [];
  437. for (let i = 0; i < obj.inputs.length; i++) {
  438. const inputOpr = allOprAndTensor.get(obj.inputs[i]);
  439. const schema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: `input${i === 0 ? '' : i.toString()}` };
  440. const argument = new megengine.Argument(schema.name, inputOpr.extraInfo.args);
  441. this.inputs.push(argument);
  442. }
  443. const outputSchemas = this.type && this.type.outputs ? [...this.type.outputs] : [];
  444. for (let i = 0; i < obj.outputs.length; i++) {
  445. const outputOpr = allOprAndTensor.get(obj.outputs[i]);
  446. const schema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: `output${i === 0 ? '' : i.toString()}` };
  447. const argument = new megengine.Argument(schema.name, outputOpr.extraInfo.args);
  448. this.outputs.push(argument);
  449. }
  450. if (obj.param) {
  451. for (const [name, value] of Object.entries(obj.param)) {
  452. if (value !== null) {
  453. const schema = metadata.attribute(obj.param.constructor.name, name);
  454. attributes.push([schema, name, value]);
  455. }
  456. }
  457. }
  458. }
  459. if (expr) {
  460. let inpIdx = 0;
  461. for (const i of expr.inputs) {
  462. if (i.__class__.__name__ !== 'ModuleNode') {
  463. const initializer = i.initializer === undefined ? null : i.initializer;
  464. const name = `input${inpIdx === 0 ? '' : inpIdx}`;
  465. const dtype = i._dtype ? i._dtype.__name__ : null;
  466. const shape = new megengine.TensorShape(i._shape);
  467. const type = new megengine.TensorType(dtype, shape);
  468. const value = values.map(i._fullname, type, initializer);
  469. const argument = new megengine.Argument(name, [value]);
  470. this.inputs.push(argument);
  471. inpIdx += 1;
  472. }
  473. }
  474. let outIdx = 0;
  475. let qparams = null;
  476. for (const o of expr.outputs) {
  477. if (o._qparams !== null) {
  478. qparams = o._qparams[1];
  479. }
  480. const name = `output${outIdx === 0 ? '' : outIdx}`;
  481. const dtype = o._dtype ? o._dtype.__name__ : null;
  482. const shape = new megengine.TensorShape(o._shape);
  483. const type = new megengine.TensorType(dtype, shape);
  484. const value = values.map(o._fullname, type, null);
  485. const argument = new megengine.Argument(name, [value]);
  486. this.outputs.push(argument);
  487. outIdx += 1;
  488. }
  489. if (qparams !== null) {
  490. state = state === null ? {} : state;
  491. state.scale = qparams.scale;
  492. state.zero_point = qparams.zero_point;
  493. state.quant_dtype_meta = qparams.dtype_meta;
  494. }
  495. if (state !== null) {
  496. for (const [key, obj] of Array.from(state)) {
  497. const isModule = (obj) => {
  498. return obj && (obj.state || obj._forward_pre_hooks);
  499. };
  500. const isTensor = (obj) => {
  501. return obj && obj.__class__ && obj.__class__.__module__ === 'megengine.tensor' && (obj.__class__.__name__ === 'Tensor' || obj.__class__.__name__ === 'Parameter');
  502. };
  503. if (!key.startsWith('_') && !isModule(obj)) {
  504. if (isTensor(obj)) {
  505. const tensor = obj;
  506. const dtype = tensor.dtype ? tensor.dtype.__name__ : null;
  507. const shape = new megengine.TensorShape(tensor.data.shape);
  508. const type = new megengine.TensorType(dtype, shape);
  509. const data = tensor.data.data;
  510. const initializer = new megengine.Tensor(key, type, data);
  511. const value = values.map('', type, initializer);
  512. const argument = new megengine.Argument(key, [value]);
  513. this.inputs.push(argument);
  514. } else {
  515. const value = obj === null ? 'None' : obj;
  516. attributes.push([null, key, value]);
  517. }
  518. }
  519. }
  520. }
  521. }
  522. this.attributes = attributes.map(([metadata, name, value]) => {
  523. value = ArrayBuffer.isView(value) ? Array.from(value) : value;
  524. let type = metadata ? metadata.type : null;
  525. let visible = true;
  526. if (name === 'training') {
  527. visible = false;
  528. type = 'boolean';
  529. }
  530. if (megengine.schema) {
  531. if (megengine.schema.param[type]) {
  532. value = megengine.Utility.enum(megengine.schema.param, type, value);
  533. } else if (megengine.schema[type]) {
  534. value = megengine.Utility.enum(megengine.schema, type, value);
  535. } else if (megengine.schema.v2[type]) {
  536. value = megengine.Utility.enum(megengine.schema.v2, type, value);
  537. }
  538. }
  539. return new megengine.Argument(name, value, type, visible);
  540. });
  541. }
  542. };
  543. megengine.Tensor = class {
  544. constructor(name, type, data) {
  545. this.category = 'Tensor';
  546. this.name = name || '';
  547. this.type = type;
  548. this.values = data;
  549. }
  550. };
  551. megengine.TensorType = class {
  552. constructor(dataType, shape) {
  553. dataType = megengine.Utility.enum(megengine.schema, 'DTypeEnum', dataType);
  554. dataType = typeof dataType === 'string' ? dataType.toLowerCase() : dataType;
  555. megengine.TensorType._dataTypes = megengine.TensorType._dataTypes || new Map([
  556. ['bool', 'boolean'],
  557. ['byte', 'uint8'], ['quantizeds4asymm', 'uint8'], ['quantizeds8asymm', 'uint8'], ['uintb4', 'uint8'],
  558. ['quantizeds1', 'int8'], ['quantizeds4', 'int8'], ['quantizeds8', 'int8'], ['intb1', 'int8'], ['intb2', 'int8'], ['intb4', 'int8'], ['qint8', 'int8'],
  559. ['quantizeds16', 'int16'],
  560. ['quantizeds32', 'int32']
  561. ]);
  562. this.dataType = megengine.TensorType._dataTypes.get(dataType) || dataType;
  563. this.shape = shape;
  564. }
  565. equals(obj) {
  566. return obj && this.dataType === obj.dataType && this.shape && this.shape.equals(obj.shape);
  567. }
  568. toString() {
  569. return this.dataType + this.shape.toString();
  570. }
  571. };
  572. megengine.TensorShape = class {
  573. constructor(dimensions) {
  574. this.dimensions = Array.from(dimensions || []);
  575. }
  576. equals(obj) {
  577. return obj && Array.isArray(obj.dimensions) &&
  578. Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length
  579. && obj.dimensions.every((value, index) => this.dimensions[index] === value);
  580. }
  581. toString() {
  582. if (this.dimensions && this.dimensions.length > 0) {
  583. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  584. }
  585. return '';
  586. }
  587. };
  588. megengine.Utility = class {
  589. static enum(schema, name, value) {
  590. const type = name && schema ? schema[name] : undefined;
  591. if (type) {
  592. megengine.Utility._enums = megengine.Utility._enums || new Map();
  593. if (!megengine.Utility._enums.has(name)) {
  594. const entries = new Map(Object.entries(type).map(([key, value]) => [value, key]));
  595. megengine.Utility._enums.set(name, entries);
  596. }
  597. const map = megengine.Utility._enums.get(name);
  598. if (map.has(value)) {
  599. return map.get(value);
  600. }
  601. }
  602. return value;
  603. }
  604. };
  605. megengine.Error = class extends Error {
  606. constructor(message) {
  607. super(message);
  608. this.name = 'Error loading MegEngine model.';
  609. }
  610. };
  611. export const ModelFactory = megengine.ModelFactory;