megengine.js 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. // Experimental
  2. import * as flatbuffers from './flatbuffers.js';
  const megengine = {};

  /**
   * Detects and opens MegEngine models. Container types assigned by `match`:
   *  - 'megengine.mge'   : FlatBuffers buffer with the 'mgv2' identifier, optionally
   *                        preceded by a 12-byte header starting with 'mgbtest0';
   *  - 'megengine.<tag>' : legacy files starting with 'mgb0001', 'mgb0000a', 'MGBS' or
   *                        'MGBC' (recognized, but rejected by `open`);
   *  - 'megengine.tm'    : a pickled megengine TracedModule.
   */
  megengine.ModelFactory = class {

      /**
       * Tags `context` with a MegEngine container type.
       * @param {object} context - host context exposing `stream`, `peek(format)` and `set(type, value)`.
       * @returns {Promise<*>} the result of `context.set(...)`, or null when nothing matches.
       */
      async match(context) {
          const stream = context.stream;
          if (stream && stream.length >= 12) {
              let buffer = stream.peek(12);
              const tag = String.fromCharCode.apply(null, buffer);
              // An 'mgbtest0' header shifts the FlatBuffers payload by 12 bytes.
              const position = tag.startsWith('mgbtest0') ? 12 : 0;
              if (stream.length > (position + 12)) {
                  // Peek exactly the bytes needed; a fixed peek(24) over-reads streams of
                  // 13-23 bytes when there is no header.
                  buffer = stream.peek(position + 12).slice(position, position + 12);
                  // Little-endian uint32 size prefix; '>>> 0' keeps it unsigned.
                  const size = (buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24)) >>> 0;
                  if (position > 0 || size === (stream.length - position - 4)) {
                      const reader = flatbuffers.BinaryReader.open(stream, position + 4);
                      if (reader && reader.identifier === 'mgv2') {
                          return context.set('megengine.mge', reader);
                      }
                  }
              }
              for (const value of ['mgb0001', 'mgb0000a', 'MGBS', 'MGBC']) {
                  if (tag.startsWith(value)) {
                      return context.set(`megengine.${value}`);
                  }
              }
          }
          const obj = await context.peek('pkl');
          if (obj && obj.__class__ && obj.__class__.__module__ === 'megengine.traced_module.traced_module' && obj.__class__.__name__ === 'TracedModule') {
              return context.set('megengine.tm');
          }
          return null;
      }

      /**
       * Builds a megengine.Model for the container type chosen by `match`.
       * For 'megengine.mge' this loads './megengine-schema' and stores its
       * `mgb.serialization.fbs` namespace in `megengine.schema` (read elsewhere for enum names).
       * @param {object} context - host context; `context.type` is the matched type.
       * @returns {Promise<megengine.Model>}
       * @throws {megengine.Error} when decoding fails or the format is not supported.
       */
      async open(context) {
          const metadata = await context.metadata('megengine-metadata.json');
          switch (context.type) {
              case 'megengine.tm': {
                  const obj = await context.peek('pkl');
                  return new megengine.Model(metadata, obj, context.type);
              }
              case 'megengine.mge': {
                  megengine.schema = await context.require('./megengine-schema');
                  megengine.schema = megengine.schema.mgb.serialization.fbs;
                  let model = null;
                  try {
                      const reader = context.value;
                      model = megengine.schema.v2.Model.create(reader);
                  } catch (error) {
                      // String(error) also copes with thrown null/undefined/primitives.
                      const message = error && error.message ? error.message : String(error);
                      throw new megengine.Error(`File format is not megengine.Model (${message.replace(/\.$/, '')}).`);
                  }
                  return new megengine.Model(metadata, model, context.type);
              }
              default: {
                  throw new megengine.Error(`Unsupported MegEngine format '${context.type.replace(/^megengine\./, '')}'.`);
              }
          }
      }
  };
  megengine.Model = class {

      /**
       * Top-level model: a format label plus a single graph.
       * @param {object} metadata - operator metadata passed through to the graph.
       * @param {object} obj - unpickled TracedModule ('megengine.tm') or decoded FlatBuffers model ('megengine.mge').
       * @param {string} type - container type chosen by ModelFactory.match.
       */
      constructor(metadata, obj, type) {
          let suffix = '';
          switch (type) {
              case 'megengine.tm': {
                  const info = obj.dump_info;
                  if (info && info.has('version')) {
                      suffix = ` v${info.get('version')}`;
                  }
                  break;
              }
              case 'megengine.mge': {
                  const version = obj.model_version;
                  suffix = version ? ` Mge v${version}` : ' Mge';
                  break;
              }
              default: {
                  break;
              }
          }
          this.format = `MegEngine${suffix}`;
          this.modules = [new megengine.Graph(metadata, obj)];
      }
  };
  megengine.Graph = class {

      /**
       * Builds the nodes, graph inputs and graph outputs of a MegEngine model.
       * Two container flavours are handled:
       *  - a TracedModule (recognized by `obj.argdef_graph_map`), walked expression by
       *    expression with nested TracedModules inlined under a name prefix;
       *  - an 'mgv2' FlatBuffers model, whose `oprs` list is flattened into nodes.
       * @param {object} metadata - operator metadata exposing `type(name)` and `attribute(type, name)`.
       * @param {object} obj - the unpickled TracedModule or the decoded FlatBuffers model.
       */
      constructor(metadata, obj) {
          this.name = '';
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          const values = new Map();
          // Returns the shared Value registered under `name`, creating it on first use.
          // A Value backed by an initializer is never shared: a fresh one is returned.
          // Re-registering `name` with a different non-null type is an error.
          values.map = (name, type, tensor, quantization) => {
              const valueType = type || null;
              if (tensor) {
                  return new megengine.Value(name, valueType, tensor, quantization);
              }
              if (!values.has(name)) {
                  values.set(name, new megengine.Value(name, valueType, null, quantization));
              } else if (valueType && !valueType.equals(values.get(name).type)) {
                  throw new megengine.Error(`Duplicate value '${name}'.`);
              }
              return values.get(name);
          };
          if (obj.argdef_graph_map) {
              this._loadTracedModule(metadata, obj, values);
          } else {
              this._loadFlatBuffers(metadata, obj, values);
          }
      }

      // Walks the first graph of a TracedModule; root tensor inputs/outputs become graph
      // inputs/outputs and expressions become nodes (nested TracedModules are inlined).
      _loadTracedModule(metadata, obj, values) {
          const getTensorType = (dtype, shape) => {
              const dataType = dtype ? dtype.__name__ : null;
              return new megengine.TensorType(dataType, new megengine.TensorShape(shape));
          };
          // Follows a chain of GetAttr expressions (e.g. `self.block.conv`) starting at `module`.
          const parseGetAttr = (module, expression) => {
              let current = expression;
              let names = current.name.split('.');
              while (current.inputs[0].expr.__class__.__name__ === 'GetAttr') {
                  current = current.inputs[0].expr;
                  names = current.name.split('.').concat(names);
              }
              let target = module;
              for (const name of names) {
                  target = target[name];
              }
              return target;
          };
          // Renders a pytree TreeDef/LeafDef as a compact Python-like string.
          const formatTreeDef = (def) => {
              const kind = def.__class__.__name__;
              if (kind !== 'TreeDef' && kind !== 'LeafDef') {
                  throw new megengine.Error(`Invalid argument '${kind}'.`);
              }
              if (kind === 'TreeDef') {
                  const type = typeof def.type === 'string' ? def.type.split('.').slice(-1)[0] : def.type.__name__;
                  const list = def.children_defs.map((child) => formatTreeDef(child));
                  switch (type) {
                      case 'tuple': {
                          return `(${list.join(',')})`;
                      }
                      case 'slice': {
                          return list.join(':');
                      }
                      case 'list': {
                          return `[${list.join(',')}]`;
                      }
                      case 'dict': {
                          // Keys are stored in `aux_data`, aligned with `children_defs`.
                          const items = list.map((item, i) => `${def.aux_data[i]}:${item}`);
                          return `{${items.join(',')}}`;
                      }
                      default: {
                          return `${type}(${list.join(',')})`;
                      }
                  }
              }
              if (def.const_val !== null) {
                  return def.const_val;
              }
              if (def.type[0].__module__ !== undefined) {
                  return def.type[0].__name__;
              }
              return 'None';
          };
          // Replaces each 'Tensor' placeholder with the next positional input name
          // ('input', 'input1', ...); returns the new text and the next input index.
          const processArgs = (text, startIdx) => {
              let result = text;
              let index = startIdx;
              while (typeof result === 'string' && result.includes('Tensor')) {
                  result = result.replace('Tensor', `input${index === 0 ? '' : index}`);
                  index += 1;
              }
              return [result, index];
          };
          // Names positional arguments from the schema (falling back to 'argN' when the schema
          // is missing or does not fit) and renders every argument's tree definition.
          const parseArgs = (args, kwargs, meta) => {
              const state = {};
              const attributes = meta && Array.isArray(meta.attributes) ? meta.attributes : null;
              const varargs = meta && meta.varargs !== undefined ? meta.varargs : null;
              const count = args.children_defs.length;
              let inputIdx = 0;
              args.children_defs.forEach((arg, argIdx) => {
                  let name = '';
                  if (attributes === null || (attributes.length !== count && varargs === null)) {
                      name = `arg${argIdx}`;
                  } else if (argIdx < attributes.length) {
                      name = attributes[argIdx].name;
                  } else {
                      name = varargs + (argIdx - attributes.length);
                  }
                  const [value, next] = processArgs(formatTreeDef(arg), inputIdx);
                  state[name] = value;
                  inputIdx = next;
              });
              kwargs.children_defs.forEach((kwarg, i) => {
                  const [value, next] = processArgs(formatTreeDef(kwarg), inputIdx);
                  state[kwargs.aux_data[i]] = value;
                  inputIdx = next;
              });
              return state;
          };
          // Follows the alias chain recorded in `context` (internal name -> caller name)
          // to its final name; stops on self-references and on longer cycles.
          const getName = (context, name) => {
              const visited = new Set();
              let current = name;
              while (context.get(current) !== undefined && !visited.has(current)) {
                  visited.add(current);
                  const next = context.get(current);
                  if (next === current) {
                      return current;
                  }
                  current = next;
              }
              return current;
          };
          const getFullName = (prefix, name) => (prefix === '' ? name : `${prefix}_${name}`);
          const getModuleType = (module) => {
              if (module.module !== undefined) {
                  return `${module.module[0]}.${module.module[1]}`;
              }
              return `${module.__class__.__module__}.${module.__class__.__name__}`;
          };
          const getFunctionType = (expression) => {
              if (expression.func.__module__ !== undefined) {
                  return `${expression.func.__module__}.${expression.func.__name__}`;
              }
              return `${expression.func[0]}.${expression.func[1]}`;
          };
          const loadGraph = (tmodule, igraph, context, namePrefix, isRoot) => {
              if (isRoot) {
                  for (const node of igraph._inputs) {
                      if (node.__class__.__name__ !== 'ModuleNode') {
                          const type = getTensorType(node._dtype, node._shape);
                          const value = values.map(node._name, type, null);
                          this.inputs.push(new megengine.Argument(node._name, [value]));
                      }
                  }
                  for (const node of igraph._outputs) {
                      const type = getTensorType(node._dtype, node._shape);
                      const value = values.map(node._name, type, null);
                      this.outputs.push(new megengine.Argument(node._name, [value]));
                  }
              }
              for (const expression of igraph._exprs) {
                  for (const input of expression.inputs) {
                      input._fullname = getName(context, getFullName(namePrefix, input._name));
                  }
                  for (const output of expression.outputs) {
                      output._fullname = getName(context, getFullName(namePrefix, output._name));
                  }
                  switch (expression.__class__.__name__) {
                      case 'GetAttr': {
                          if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
                              const tensor = parseGetAttr(tmodule, expression);
                              const type = getTensorType(tensor.dtype, tensor.data.shape);
                              expression.outputs[0].initializer = new megengine.Tensor(expression.name, type, tensor.data.data);
                          }
                          break;
                      }
                      case 'Constant': {
                          if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
                              const tensor = expression.value;
                              const type = getTensorType(tensor.dtype, tensor.data.shape);
                              expression.outputs[0].initializer = new megengine.Tensor('', type, tensor.data.data);
                          }
                          break;
                      }
                      case 'CallMethod': {
                          const item = { name: '' };
                          let state = {};
                          if (expression.method === '__call__') {
                              const module = parseGetAttr(tmodule, expression.inputs[0].expr);
                              if (module.__class__.__name__ === 'TracedModule') {
                                  // Inline the nested TracedModule: alias its internal input/output
                                  // names to the caller's names, then walk it under a new prefix.
                                  const outputName = expression.outputs[0]._name;
                                  const moduleName = outputName.endsWith('_out') ? outputName.substring(0, outputName.length - 4) : outputName;
                                  const prefix = getFullName(namePrefix, moduleName);
                                  const internalGraph = module.argdef_graph_map.get(expression.arg_def);
                                  expression.inputs.forEach((input, i) => {
                                      const actualName = getFullName(namePrefix, input._name);
                                      const internalName = getFullName(prefix, internalGraph._inputs[i]._name);
                                      context.set(internalName, actualName);
                                  });
                                  expression.outputs.forEach((output, i) => {
                                      const actualName = getFullName(namePrefix, output._name);
                                      const internalName = getFullName(prefix, internalGraph._outputs[i]._name);
                                      if (context.get(internalName) === undefined) {
                                          context.set(internalName, actualName);
                                      } else {
                                          context.set(actualName, context.get(internalName));
                                      }
                                  });
                                  loadGraph(module, internalGraph, context, prefix, false);
                                  break;
                              }
                              item.type = getModuleType(module);
                              state = module.state === undefined ? module : module.state;
                          } else {
                              item.type = expression.method;
                              const [args, kwargs] = expression.arg_def.children_defs;
                              state = parseArgs(args, kwargs, metadata.type(expression.method));
                          }
                          this.nodes.push(new megengine.Node(metadata, item, values, null, expression, state));
                          break;
                      }
                      case 'CallFunction': {
                          const func = getFunctionType(expression);
                          const item = { name: '', type: func };
                          const [args, kwargs] = expression.arg_def.children_defs;
                          const state = parseArgs(args, kwargs, metadata.type(func));
                          this.nodes.push(new megengine.Node(metadata, item, values, null, expression, state));
                          break;
                      }
                      case 'Apply': {
                          const opdef = expression.opdef_state ? expression.opdef_state.get('opdef_type') : expression.opdef.type;
                          const item = { name: '', type: `${opdef.__module__}.${opdef.__name__}` };
                          // `opdef_state` may be absent; pass null so Node skips attribute extraction.
                          this.nodes.push(new megengine.Node(metadata, item, values, null, expression, expression.opdef_state || null));
                          break;
                      }
                      default: {
                          // 'Input' and unrecognized expressions produce no node.
                          break;
                      }
                  }
              }
          };
          const [graph] = Array.from(obj.argdef_graph_map.values());
          loadGraph(obj, graph, new Map(), '', true);
      }

      // Flattens an 'mgv2' FlatBuffers model: 'Host2DeviceCopy' operators become graph inputs,
      // 'ImmutableTensor' operators only provide values, the rest become nodes, and
      // `output_vars_idx` selects the graph outputs.
      _loadFlatBuffers(metadata, obj, values) {
          const extraInfoNameset = new Set();
          // Derives a unique display name for an operator output and the Value for it.
          const getExtraInfo = (opr) => {
              let name = opr.name;
              let repeatIdx = 0;
              while (extraInfoNameset.has(name)) {
                  for (const id of opr.inputs) {
                      name = `${name}[${id}]`;
                  }
                  name += repeatIdx;
                  repeatIdx += 1;
              }
              extraInfoNameset.add(name);
              const type = opr.type.replace(/V(\d+)$/, '');
              const args = [];
              if (opr.tensors.length > 0) {
                  const [tensor] = opr.tensors;
                  const tensorType = new megengine.TensorType(tensor.dtype.type, new megengine.TensorShape(tensor.shape));
                  const data = tensor.data.byteLength === 0 ? undefined : tensor.data.slice(0);
                  // Host2DeviceCopy operators feed graph inputs, so their tensor is not a constant.
                  const initializer = opr.type === 'Host2DeviceCopy' ? undefined : new megengine.Tensor('', tensorType, data);
                  const quantization = tensor.dtype.param ? { scale: tensor.dtype.param.scale, zeroPoint: tensor.dtype.param.zero_point } : null;
                  args.push(values.map(name, tensorType, initializer, quantization));
              } else {
                  const tensorType = opr.shape ? new megengine.TensorType('?', new megengine.TensorShape(opr.shape)) : null;
                  args.push(values.map(name, tensorType));
              }
              return { name, type, args };
          };
          // Indexes output id -> operator record (with `extraInfo`); operators with several
          // outputs (and tensor holders) are split into one record per output.
          const getAllOprAndTensor = (oprs) => {
              const allOprAndTensor = new Map();
              for (const opr of oprs) {
                  if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.outputs.length > 1) {
                      if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.type === 'MultipleDeviceTensorHolder') {
                          opr.type = 'ImmutableTensor';
                      }
                      for (let id = 0; id < opr.outputs.length; id++) {
                          const keyId = opr.outputs[id];
                          const middle = obj.middle_tensors[keyId];
                          const name = middle ? middle.name : String(keyId);
                          // Tolerate fewer tensors than outputs instead of wrapping `undefined`.
                          const tensors = opr.tensors[id] ? [opr.tensors[id]] : [];
                          const shape = middle ? middle.shape : [];
                          const record = { name, type: opr.type, tensors, shape, inputs: opr.inputs, outputs: opr.outputs };
                          allOprAndTensor.set(keyId, record);
                          record.extraInfo = getExtraInfo(record);
                      }
                  } else {
                      const [keyId] = opr.outputs;
                      const middle = obj.middle_tensors[keyId];
                      opr.name = middle ? middle.name : String(keyId);
                      if (middle && middle.shape) {
                          opr.shape = middle.shape;
                      }
                      allOprAndTensor.set(keyId, opr);
                      opr.extraInfo = getExtraInfo(opr);
                  }
              }
              return allOprAndTensor;
          };
          const allOprAndTensor = getAllOprAndTensor(obj.oprs);
          for (const op of allOprAndTensor.values()) {
              if (op.type === 'Host2DeviceCopy') {
                  this.inputs.push(new megengine.Argument('input', op.extraInfo.args));
              } else if (op.type !== 'ImmutableTensor') {
                  this.nodes.push(new megengine.Node(metadata, op, values, allOprAndTensor));
              }
          }
          for (let i = 0; i < obj.output_vars_idx.length; i++) {
              const id = obj.output_vars_idx[i].compact_id;
              const name = `output${i === 0 ? '' : i}`;
              this.outputs.push(new megengine.Argument(name, allOprAndTensor.get(id).extraInfo.args));
          }
      }
  };
  megengine.Argument = class {

      /**
       * A named slot on a node or graph. For inputs/outputs `value` is a list of Values;
       * for attributes it is the attribute value itself.
       * @param {string} name
       * @param {*} value
       * @param {?string} [type=null] - attribute type name, when known.
       * @param {boolean} [visible=true] - false marks the argument as hidden.
       */
      constructor(name, value, type = null, visible = true) {
          Object.assign(this, { name, value, type, visible });
      }
  };
  megengine.Value = class {

      /**
       * A named tensor flowing between nodes.
       * @param {string} name - identifier; anything other than a string is rejected.
       * @param {?megengine.TensorType} type - falls back to the initializer's type when missing.
       * @param {?megengine.Tensor} initializer - constant data, if any.
       * @param {?{scale: number, zeroPoint: number}} quantization - recorded only when it is
       *        not the identity mapping (scale 1, zero point 0).
       */
      constructor(name, type, initializer, quantization) {
          if (typeof name !== 'string') {
              throw new megengine.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          this.type = (type || !initializer) ? type : initializer.type;
          this.initializer = initializer;
          if (quantization) {
              const { scale, zeroPoint } = quantization;
              const hasScale = scale !== undefined && scale !== 1;
              const hasOffset = zeroPoint !== undefined && zeroPoint !== 0;
              if (hasScale || hasOffset) {
                  this.quantization = {
                      type: 'linear',
                      scale: [scale],
                      offset: [zeroPoint]
                  };
              }
          }
      }
  };
  megengine.Node = class {

      /**
       * Builds a node for either container flavour.
       * @param {object} metadata - operator metadata exposing `type(name)` and `attribute(type, name)`.
       * @param {object} obj - a FlatBuffers operator record (`type`, `inputs`/`outputs` ids,
       *        optional `param`) or a `{ name, type }` descriptor for a traced-module expression.
       * @param {Map} values - the graph's value table (provides `values.map(...)`).
       * @param {?Map} allOprAndTensor - output id -> operator record; used when `obj` has inputs/outputs.
       * @param {?object} expr - traced-module expression whose input/output nodes become arguments.
       * @param {?(Map|object)} state - attribute source for traced-module nodes, either a Map or a
       *        plain object; it is read but never modified.
       */
      constructor(metadata, obj, values, allOprAndTensor, expr, state) {
          this.name = '';
          const type = metadata.type(obj.type);
          this.type = type ? { ...type } : { name: obj.type };
          // Drop a trailing version suffix ('...V2') and surrounding dunder markers.
          this.type.name = this.type.name.replace(/V(\d+)$/, '');
          if (this.type.name.length > 4 && this.type.name.startsWith('__') && this.type.name.endsWith('__')) {
              this.type.name = this.type.name.substring(2, this.type.name.length - 2);
          }
          this.inputs = [];
          this.outputs = [];
          this.chain = [];
          this.attributes = [];
          const attributes = [];
          if (obj.inputs && obj.outputs) {
              const inputSchemas = this.type.inputs ? [...this.type.inputs] : [];
              for (let i = 0; i < obj.inputs.length; i++) {
                  const inputOpr = allOprAndTensor.get(obj.inputs[i]);
                  const schema = inputSchemas.length > 0 ? inputSchemas.shift() : { name: `input${i === 0 ? '' : i.toString()}` };
                  this.inputs.push(new megengine.Argument(schema.name, inputOpr.extraInfo.args));
              }
              const outputSchemas = this.type.outputs ? [...this.type.outputs] : [];
              for (let i = 0; i < obj.outputs.length; i++) {
                  const outputOpr = allOprAndTensor.get(obj.outputs[i]);
                  const schema = outputSchemas.length > 0 ? outputSchemas.shift() : { name: `output${i === 0 ? '' : i.toString()}` };
                  this.outputs.push(new megengine.Argument(schema.name, outputOpr.extraInfo.args));
              }
              if (obj.param) {
                  for (const [name, value] of Object.entries(obj.param)) {
                      if (value !== null) {
                          const schema = metadata.attribute(obj.param.constructor.name, name);
                          attributes.push([schema, name, value]);
                      }
                  }
              }
          }
          if (expr) {
              let inputIdx = 0;
              for (const input of expr.inputs) {
                  if (input.__class__.__name__ !== 'ModuleNode') {
                      const initializer = input.initializer === undefined ? null : input.initializer;
                      const name = `input${inputIdx === 0 ? '' : inputIdx}`;
                      const dtype = input._dtype ? input._dtype.__name__ : null;
                      const tensorType = new megengine.TensorType(dtype, new megengine.TensorShape(input._shape));
                      const value = values.map(input._fullname, tensorType, initializer);
                      this.inputs.push(new megengine.Argument(name, [value]));
                      inputIdx += 1;
                  }
              }
              let qparams = null;
              expr.outputs.forEach((output, outputIdx) => {
                  if (output._qparams) {
                      qparams = output._qparams[1];
                  }
                  const name = `output${outputIdx === 0 ? '' : outputIdx}`;
                  const dtype = output._dtype ? output._dtype.__name__ : null;
                  const tensorType = new megengine.TensorType(dtype, new megengine.TensorShape(output._shape));
                  const value = values.map(output._fullname, tensorType, null);
                  this.outputs.push(new megengine.Argument(name, [value]));
              });
              // Copy the entries into a private Map: `state` may be a plain object (e.g. built
              // from call arguments or a module's own fields), for which Array.from() yields
              // nothing, and it may be the model's module state, which must not be modified.
              const entries = new Map(state instanceof Map ? state : Object.entries(state || {}));
              if (qparams) {
                  entries.set('scale', qparams.scale);
                  entries.set('zero_point', qparams.zero_point);
                  entries.set('quant_dtype_meta', qparams.dtype_meta);
              }
              const isModule = (value) => value && (value.state || value._forward_pre_hooks);
              const isTensor = (value) => value && value.__class__ && value.__class__.__module__ === 'megengine.tensor' &&
                  (value.__class__.__name__ === 'Tensor' || value.__class__.__name__ === 'Parameter');
              for (const [key, value] of entries) {
                  // Skip private fields and sub-modules.
                  if (typeof key !== 'string' || key.startsWith('_') || isModule(value)) {
                      continue;
                  }
                  if (isTensor(value)) {
                      const dtype = value.dtype ? value.dtype.__name__ : null;
                      const tensorType = new megengine.TensorType(dtype, new megengine.TensorShape(value.data.shape));
                      const initializer = new megengine.Tensor(key, tensorType, value.data.data);
                      this.inputs.push(new megengine.Argument(key, [values.map('', tensorType, initializer)]));
                  } else {
                      attributes.push([null, key, value === null ? 'None' : value]);
                  }
              }
          }
          this.attributes = attributes.map(([schema, name, value]) => {
              let attributeValue = ArrayBuffer.isView(value) ? Array.from(value) : value;
              let type = schema ? schema.type : null;
              let visible = true;
              if (name === 'training') {
                  visible = false;
                  type = 'boolean';
              }
              // Translate enum values to names when a FlatBuffers schema has been loaded.
              if (megengine.schema) {
                  if (megengine.schema.param[type]) {
                      attributeValue = megengine.Utility.enum(megengine.schema.param, type, attributeValue);
                  } else if (megengine.schema[type]) {
                      attributeValue = megengine.Utility.enum(megengine.schema, type, attributeValue);
                  } else if (megengine.schema.v2[type]) {
                      attributeValue = megengine.Utility.enum(megengine.schema.v2, type, attributeValue);
                  }
              }
              return new megengine.Argument(name, attributeValue, type, visible);
          });
      }
  };
  megengine.Tensor = class {

      /**
       * Constant tensor data attached to a Value as its initializer.
       * @param {string} [name] - replaced by '' when falsy.
       * @param {megengine.TensorType} type
       * @param {*} data - raw element data (may be undefined).
       */
      constructor(name, type, data) {
          const label = name ? name : '';
          this.category = 'Tensor';
          this.name = label;
          this.type = type;
          this.values = data;
      }
  };
  megengine.TensorType = class {

      /**
       * @param {*} dataType - a DTypeEnum value (resolved through `megengine.schema` when
       *        loaded) or a dtype name; names are lower-cased and quantized/packed types are
       *        mapped onto their storage type (e.g. 'quantizeds8' -> 'int8').
       * @param {megengine.TensorShape} shape
       */
      constructor(dataType, shape) {
          let name = megengine.Utility.enum(megengine.schema, 'DTypeEnum', dataType);
          if (typeof name === 'string') {
              name = name.toLowerCase();
          }
          if (!megengine.TensorType._dataTypes) {
              megengine.TensorType._dataTypes = new Map([
                  ['bool', 'boolean'],
                  ['byte', 'uint8'], ['quantizeds4asymm', 'uint8'], ['quantizeds8asymm', 'uint8'], ['uintb4', 'uint8'],
                  ['quantizeds1', 'int8'], ['quantizeds4', 'int8'], ['quantizeds8', 'int8'], ['intb1', 'int8'], ['intb2', 'int8'], ['intb4', 'int8'], ['qint8', 'int8'],
                  ['quantizeds16', 'int16'],
                  ['quantizeds32', 'int32']
              ]);
          }
          const alias = megengine.TensorType._dataTypes.get(name);
          this.dataType = alias ? alias : name;
          this.shape = shape;
      }

      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (!this.shape) {
              return this.shape;
          }
          return this.shape.equals(obj.shape);
      }

      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  megengine.TensorShape = class {

      /**
       * @param {?Iterable} dimensions - dimension sizes; a missing shape becomes [].
       */
      constructor(dimensions) {
          this.dimensions = Array.from(dimensions || []);
      }

      equals(obj) {
          return obj && Array.isArray(obj.dimensions) &&
              Array.isArray(this.dimensions) && this.dimensions.length === obj.dimensions.length &&
              obj.dimensions.every((value, index) => this.dimensions[index] === value);
      }

      /**
       * Formats the shape as '[d0,d1,...]', or '' for a scalar/unknown shape.
       * Unknown (null/undefined) dimensions render as '?' instead of throwing.
       */
      toString() {
          if (!Array.isArray(this.dimensions) || this.dimensions.length === 0) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => (dimension === null || dimension === undefined ? '?' : dimension.toString()));
          return `[${parts.join(',')}]`;
      }
  };
  megengine.Utility = class {

      /**
       * Maps an enum value back to its member name.
       * @param {?object} schema - namespace holding enum objects.
       * @param {?string} name - enum name within `schema`.
       * @param {*} value - raw enum value.
       * @returns {*} the member name, or `value` unchanged when the schema, the enum or the
       *          member is unknown.
       */
      static enum(schema, name, value) {
          const type = name && schema ? schema[name] : undefined;
          if (!type) {
              return value;
          }
          // Cache reverse lookups per enum object, not per name: the same enum name may exist in
          // several namespaces (e.g. `schema.param` and `schema.v2`), and a name-keyed cache
          // would answer every later call with whichever namespace was seen first.
          const cacheable = typeof type === 'object' || typeof type === 'function';
          megengine.Utility._enumCache = megengine.Utility._enumCache || new WeakMap();
          let map = cacheable ? megengine.Utility._enumCache.get(type) : undefined;
          if (!map) {
              map = new Map(Object.entries(type).map(([key, entry]) => [entry, key]));
              if (cacheable) {
                  megengine.Utility._enumCache.set(type, map);
              }
          }
          return map.has(value) ? map.get(value) : value;
      }
  };
  megengine.Error = class extends Error {

      /**
       * Error raised for any MegEngine loading failure.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          super(message);
          // Fixed title shared by every error raised from this module.
          this.name = 'Error loading MegEngine model.';
      }
  };

  // Public entry point of this module.
  export const ModelFactory = megengine.ModelFactory;