megengine.js 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. // Experimental
  2. import * as flatbuffers from './flatbuffers.js';
  // Namespace object; every loader class in this module is attached to it.
  const megengine = {};

  /**
   * Recognizes MegEngine model files and opens them as `megengine.Model` objects.
   */
  megengine.ModelFactory = class {

      /**
       * Probes `context` and, on a match, stores the detected format in `context.type`
       * (and the flatbuffers reader in `context.target` for the 'mgv2' layout).
       * Leaves `context` untouched when nothing matches.
       *
       * @param context - object exposing `stream` (with `length` and `peek(n)`) and `peek('pkl')`.
       */
      match(context) {
          const stream = context.stream;
          if (stream && stream.length >= 12) {
              const tag = String.fromCharCode.apply(null, stream.peek(12));
              // Files tagged 'mgbtest0' carry a 12 byte prefix before the size field.
              const position = tag.startsWith('mgbtest0') ? 12 : 0;
              if (stream.length > (position + 12)) {
                  // Only ask for the bytes the length check above guarantees; requesting a
                  // fixed 24 bytes over-reads streams that are 13..23 bytes long.
                  const buffer = stream.peek(position + 12).slice(position, position + 12);
                  // Little-endian 32-bit size; `>>> 0` keeps the value unsigned.
                  const size = (buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24)) >>> 0;
                  if (position > 0 || size === (stream.length - position - 4)) {
                      const reader = flatbuffers.BinaryReader.open(stream, position + 4);
                      if (reader.identifier === 'mgv2') {
                          context.type = 'megengine.mge';
                          context.target = reader;
                          return;
                      }
                  }
              }
              // Other signatures are only reported by name; `open` rejects them as unsupported.
              for (const value of [ 'mgb0001', 'mgb0000a', 'MGBS', 'MGBC' ]) {
                  if (tag.startsWith(value)) {
                      context.type = `megengine.${value}`;
                      return;
                  }
              }
          }
          const obj = context.peek('pkl');
          if (obj && obj.__class__ && obj.__class__.__module__ === 'megengine.traced_module.traced_module' && obj.__class__.__name__ === 'TracedModule') {
              context.type = 'megengine.tm';
          }
      }

      /**
       * Builds a `megengine.Model` for the format recorded by `match`.
       *
       * @param context - the same context object previously passed to `match`.
       * @returns {Promise} resolves to a `megengine.Model`.
       * @throws {megengine.Error} when the flatbuffers model cannot be decoded or the
       *     detected format is not supported.
       */
      async open(context) {
          const metadata = await context.metadata('megengine-metadata.json');
          switch (context.type) {
              case 'megengine.tm': {
                  const obj = context.peek('pkl');
                  return new megengine.Model(metadata, obj, context.type);
              }
              case 'megengine.mge': {
                  // Publish only the final namespace so `megengine.schema` never briefly
                  // holds the raw schema module.
                  const schema = await context.require('./megengine-schema');
                  megengine.schema = schema.mgb.serialization.fbs;
                  let model = null;
                  try {
                      model = megengine.schema.v2.Model.create(context.target);
                  } catch (error) {
                      // String() also copes with a null/undefined rejection value.
                      const message = error && error.message ? error.message : String(error);
                      throw new megengine.Error(`File format is not megengine.Model (${message.replace(/\.$/, '')}).`);
                  }
                  return new megengine.Model(metadata, model, context.type);
              }
              default: {
                  throw new megengine.Error(`Unsupported MegEngine format '${context.type.replace(/^megengine\./, '')}'.`);
              }
          }
      }
  };
  /**
   * Top-level model wrapper: a human readable `format` label plus a single graph.
   */
  megengine.Model = class {

      /**
       * @param metadata - operator metadata forwarded to `megengine.Graph`.
       * @param obj - deserialized model object.
       * @param type - format identifier; 'megengine.tm' and 'megengine.mge' extend the label.
       */
      constructor(metadata, obj, type) {
          let label = 'MegEngine';
          switch (type) {
              case 'megengine.tm': {
                  const info = obj.dump_info;
                  if (info && info.version) {
                      label += ` v${info.version}`;
                  }
                  break;
              }
              case 'megengine.mge': {
                  const version = obj.model_version ? ` v${obj.model_version}` : '';
                  label += ` Mge${version}`;
                  break;
              }
              default: {
                  break;
              }
          }
          this.format = label;
          this.graphs = [ new megengine.Graph(metadata, obj) ];
      }
  };
  /**
   * Graph (inputs, outputs, nodes) built from a deserialized MegEngine model.
   *
   * The constructor handles two object layouts:
   * - objects carrying `argdef_graph_map` are walked expression by expression by `loadGraph`;
   * - all other objects are read through `oprs`, `middle_tensors` and `output_vars_idx`.
   */
  megengine.Graph = class {

      /**
       * @param metadata - operator metadata; `metadata.type(name)` is queried for operator schemas.
       * @param obj - deserialized model object (see the class comment for the two layouts).
       * @throws {megengine.Error} on conflicting duplicate value names or invalid argument trees.
       */
      constructor(metadata, obj) {
          this.name = '';
          this.nodes = [];
          this.inputs = [];
          this.outputs = [];
          const values = new Map();
          // Returns the shared megengine.Value for `name`, creating it on first use.
          // Unnamed initializers always get a fresh Value. `quantization` is forwarded to
          // megengine.Value so quantized tensors keep their scale / zero point.
          const value = (name, type, tensor, quantization) => {
              if (tensor && name.length === 0) {
                  return new megengine.Value(name, type || null, tensor, quantization);
              }
              if (!values.has(name)) {
                  values.set(name, new megengine.Value(name, type || null, tensor || null, quantization));
              } else if ((type && !type.equals(values.get(name).type)) || tensor) {
                  throw new megengine.Error(`Duplicate value '${name}'.`);
              }
              return values.get(name);
          };
          // Walks the expressions of one internal graph. Nested traced modules are inlined
          // recursively; `context` maps internal value names to the names used by the caller
          // and `namePrefix` scopes names of the nested graph.
          const loadGraph = (tmodule, igraph, context, namePrefix, metadata, isRoot) => {
              const expressions = igraph._exprs;
              const getTensorType = (dtype, shape) => {
                  dtype = dtype ? dtype.__name__ : null;
                  return new megengine.TensorType(dtype, new megengine.TensorShape(shape));
              };
              const isModule = (obj) => {
                  return obj && (obj.state || obj._forward_pre_hooks);
              };
              const isTensor = (obj) => {
                  return obj && obj.__class__ && obj.__class__.__module__ === 'megengine.tensor' && (obj.__class__.__name__ === 'Tensor' || obj.__class__.__name__ === 'Parameter');
              };
              // Creates a node for `expr`; entries of `state` become attributes, or
              // initializer inputs when they are tensors.
              const getOpNode = (metadata, item, expr, state) => {
                  const node = new megengine.Node(metadata, item);
                  let inpIdx = 0;
                  for (const i of expr.inputs) {
                      if (i.__class__.__name__ !== 'ModuleNode') {
                          const initializer = i.initializer !== undefined ? i.initializer : null;
                          const name = `inp${inpIdx}`;
                          const type = getTensorType(i._dtype, i._shape);
                          const argument = new megengine.Argument(name, [ value(i._fullname, type, initializer) ]);
                          node.inputs.push(argument);
                          inpIdx += 1;
                      }
                  }
                  let outIdx = 0;
                  let qparams = null;
                  for (const o of expr.outputs) {
                      if (o._qparams !== null && o._qparams !== undefined) {
                          /* eslint-disable prefer-destructuring */
                          qparams = o._qparams[1];
                          /* eslint-enable prefer-destructuring */
                      }
                      const type = getTensorType(o._dtype, o._shape);
                      const argument = new megengine.Argument(`out${outIdx}`, [ value(o._fullname, type, null) ]);
                      node.outputs.push(argument);
                      // Number the outputs the same way the inputs are numbered.
                      outIdx += 1;
                  }
                  if (qparams !== null && qparams !== undefined) {
                      // `state` may be null or undefined (e.g. an expression without opdef_state).
                      state = state === null || state === undefined ? {} : state;
                      state.scale = qparams.scale;
                      state.zero_point = qparams.zero_point;
                      state.quant_dtype_meta = qparams.dtype_meta;
                  }
                  if (state !== null) {
                      for (const key in state) {
                          if (!key.startsWith('_') && !isModule(state[key])) {
                              if (!isTensor(state[key])) {
                                  const attribute = new megengine.Attribute(null, key, state[key] !== null ? state[key] : 'None');
                                  node.attributes.push(attribute);
                              } else {
                                  const tensor = state[key];
                                  const type = getTensorType(tensor.dtype, tensor.data.shape);
                                  const data = tensor.data.data;
                                  const initializer = new megengine.Tensor(key, type, data);
                                  const argument = new megengine.Argument(key, [ value('', type, initializer) ]);
                                  node.inputs.push(argument);
                              }
                          }
                      }
                  }
                  return node;
              };
              if (isRoot) {
                  for (const node of igraph._inputs) {
                      if (node.__class__.__name__ !== 'ModuleNode') {
                          const type = getTensorType(node._dtype, node._shape);
                          const argument = new megengine.Argument(node._name, [ value(node._name, type, null) ]);
                          this.inputs.push(argument);
                      }
                  }
                  for (const node of igraph._outputs) {
                      const type = getTensorType(node._dtype, node._shape);
                      const argument = new megengine.Argument(node._name, [ value(node._name, type, null) ]);
                      this.outputs.push(argument);
                  }
              }
              // Follows a chain of GetAttr expressions and returns the attribute of `module`
              // addressed by the combined dotted path.
              const parseGetAttr = (module, expression) => {
                  let names = expression.name.split('.');
                  while (expression.inputs[0].expr.__class__.__name__ === 'GetAttr') {
                      expression = expression.inputs[0].expr;
                      names = expression.name.split('.').concat(names);
                  }
                  let obj = module;
                  for (const name of names) {
                      obj = obj[name];
                  }
                  return obj;
              };
              // Converts positional and keyword argument trees into a flat name -> text map.
              const parseArgs = (args, kwargs, meta) => {
                  const state = {};
                  let argIdx = 0;
                  // Replaces each 'Tensor' placeholder with the next 'inpN' name.
                  const processArgs = (inp, startIdx) => {
                      while (typeof inp === 'string' && inp.indexOf('Tensor') !== -1) {
                          inp = inp.replace('Tensor', `inp${startIdx}`);
                          startIdx += 1;
                      }
                      return [ inp, startIdx ];
                  };
                  const formatTreeDef = (obj) => {
                      if (obj.__class__.__name__ !== 'TreeDef' && obj.__class__.__name__ !== 'LeafDef') {
                          throw new megengine.Error(`Invalid argument '${obj.__class__.__name__}'.`);
                      }
                      if (obj.__class__.__name__ === 'TreeDef') {
                          const type = typeof obj.type !== 'string' ? obj.type.__name__ : obj.type.split('.').slice(-1)[0];
                          const list = obj.children_defs.map((child) => formatTreeDef(child));
                          switch (type) {
                              case 'tuple': {
                                  return `(${list.join(',')})`;
                              }
                              case 'slice': {
                                  return list.join(':');
                              }
                              case 'list': {
                                  return `[${list.join(',')}]`;
                              }
                              case 'dict': {
                                  // Keys come from the tree definition being formatted (`obj`);
                                  // `this` is the Graph instance inside this arrow function.
                                  const content = list.map((item, index) => `${obj.aux_data[index]}:${item}`).join(',');
                                  return `{${content}}`;
                              }
                              default: {
                                  return `${type}(${list.join(',')})`;
                              }
                          }
                      }
                      if (obj.const_val !== null) {
                          return obj.const_val;
                      } else if (obj.type[0].__module__ !== undefined) {
                          return obj.type[0].__name__;
                      }
                      return 'None';
                  };
                  let inpIdx = 0;
                  for (const arg of args.children_defs) {
                      let name = '';
                      if (meta.attributes === undefined || (meta.attributes.length !== args.children_defs.length && meta.varargs === null)) {
                          name = `arg${argIdx}`;
                      } else if (argIdx < meta.attributes.length) {
                          name = meta.attributes[argIdx].name;
                      } else {
                          name = meta.varargs + (argIdx - meta.attributes.length);
                      }
                      const [text, index] = processArgs(formatTreeDef(arg), inpIdx);
                      state[name] = text;
                      inpIdx = index;
                      argIdx += 1;
                  }
                  for (let i = 0; i < kwargs.children_defs.length; i++) {
                      const [text, index] = processArgs(formatTreeDef(kwargs.children_defs[i]), inpIdx);
                      state[kwargs.aux_data[i]] = text;
                      inpIdx = index;
                  }
                  return state;
              };
              // Resolves `name` through the alias map until it is unmapped or maps to itself.
              const getName = (context, name) => {
                  let rst = name;
                  while (context.get(rst) !== undefined) {
                      if (rst === context.get(rst)) {
                          return rst;
                      }
                      rst = context.get(rst);
                  }
                  return rst;
              };
              const getFullName = (prefix, name) => {
                  return prefix === '' ? name : `${prefix}_${name}`;
              };
              for (const expression of expressions) {
                  const type = expression.__class__.__name__;
                  for (const input of expression.inputs) {
                      input._fullname = getName(context, getFullName(namePrefix, input._name));
                  }
                  for (const output of expression.outputs) {
                      output._fullname = getName(context, getFullName(namePrefix, output._name));
                  }
                  switch (type) {
                      case 'Input': {
                          break;
                      }
                      case 'GetAttr': {
                          if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
                              const tensor = parseGetAttr(tmodule, expression);
                              const type = getTensorType(tensor.dtype, tensor.data.shape);
                              const data = tensor.data.data;
                              expression.outputs[0].initializer = new megengine.Tensor(expression.name, type, data);
                          }
                          break;
                      }
                      case 'Constant': {
                          if (expression.outputs[0].__class__.__name__ === 'TensorNode') {
                              const tensor = expression.value;
                              const type = getTensorType(tensor.dtype, tensor.data.shape);
                              const data = tensor.data.data;
                              expression.outputs[0].initializer = new megengine.Tensor('', type, data);
                          }
                          break;
                      }
                      case 'CallMethod': {
                          if (expression.method === '__call__') {
                              const module = parseGetAttr(tmodule, expression.inputs[0].expr);
                              const getModuleType = (obj) => {
                                  if (obj.module !== undefined) {
                                      return `${obj.module[0]}.${obj.module[1]}`;
                                  }
                                  return `${obj.__class__.__module__}.${obj.__class__.__name__}`;
                              };
                              const moduleType = module.__class__.__name__ !== 'TracedModule' ? getModuleType(module) : 'TracedModule';
                              if (moduleType === 'TracedModule') {
                                  // Inline the nested traced module: alias its internal input and
                                  // output names to the caller's names, then load its graph.
                                  const moduleName = expression.outputs[0]._name.endsWith("_out") ? expression.outputs[0]._name.substring(0, expression.outputs[0]._name.length - 4) : expression.outputs[0]._name;
                                  const prefix = getFullName(namePrefix, moduleName);
                                  const internalGraph = module.argdef_graph_map[expression.arg_def.toString()];
                                  for (let i = 0; i < expression.inputs.length; i++) {
                                      const actualName = getFullName(namePrefix, expression.inputs[i]._name);
                                      const internalName = getFullName(prefix, internalGraph._inputs[i]._name);
                                      context.set(internalName, actualName);
                                  }
                                  for (let i = 0; i < expression.outputs.length; i++) {
                                      const actualName = getFullName(namePrefix, expression.outputs[i]._name);
                                      const internalName = getFullName(prefix, internalGraph._outputs[i]._name);
                                      if (context.get(internalName) !== undefined) {
                                          context.set(actualName, context.get(internalName));
                                      } else {
                                          context.set(internalName, actualName);
                                      }
                                  }
                                  loadGraph(module, internalGraph, context, prefix, metadata, false);
                                  continue;
                              }
                              const item = { 'name': '', 'type': moduleType };
                              let state = module.__class__.__name__ !== 'TracedModule' ? module.state : module;
                              if (state === undefined) {
                                  state = module;
                              }
                              const node = getOpNode(metadata, item, expression, state);
                              this.nodes.push(node);
                          } else {
                              const item = { 'name': '', 'type': expression.method };
                              const [args, kwargs] = expression.arg_def.children_defs;
                              const schema = metadata.type(expression.method);
                              const state = parseArgs(args, kwargs, schema);
                              const node = getOpNode(metadata, item, expression, state);
                              this.nodes.push(node);
                          }
                          break;
                      }
                      case 'CallFunction': {
                          const getFunctionType = (obj) => {
                              if (obj.func.__module__ !== undefined) {
                                  return `${obj.func.__module__}.${obj.func.__name__}`;
                              }
                              return `${obj.func[0]}.${obj.func[1]}`;
                          };
                          const func = getFunctionType(expression);
                          const item = { 'name': '', 'type': func };
                          const [args, kwargs] = expression.arg_def.children_defs;
                          const schema = metadata.type(func);
                          const state = parseArgs(args, kwargs, schema);
                          const node = getOpNode(metadata, item, expression, state);
                          this.nodes.push(node);
                          break;
                      }
                      case 'Apply': {
                          const opdef = expression.opdef_state ? expression.opdef_state.opdef_type : expression.opdef.type;
                          const item = { 'name': '', 'type': `${opdef.__module__}.${opdef.__name__}` };
                          const node = getOpNode(metadata, item, expression, expression.opdef_state);
                          this.nodes.push(node);
                          break;
                      }
                      default: {
                          break;
                      }
                  }
              }
          };
          if (obj.argdef_graph_map) {
              const [graph] = Object.values(obj.argdef_graph_map);
              loadGraph(obj, graph, new Map(), '', metadata, true);
              return;
          }
          const extraInfoNameset = new Set();
          // Derives a unique value name for `opr` and the value list that represents it.
          const getExtraInfo = (opr) => {
              let name = opr.name;
              let repeatIdx = 0;
              while (extraInfoNameset.has(name)) {
                  for (const id of opr.inputs) {
                      name = `${name}[${id}]`;
                  }
                  name += repeatIdx;
                  repeatIdx += 1;
              }
              extraInfoNameset.add(name);
              const type = opr.type.replace(/V(\d+)$/, '');
              const args = [];
              if (opr.tensors.length > 0) {
                  const [tensor] = opr.tensors;
                  const tensorType = new megengine.TensorType(tensor.dtype.type, new megengine.TensorShape(tensor.shape));
                  const data = tensor.data.byteLength !== 0 ? tensor.data.slice(0) : undefined;
                  const initializer = opr.type === 'Host2DeviceCopy' ? undefined : new megengine.Tensor('', tensorType, data);
                  const quantization = tensor.dtype.param ? { scale: tensor.dtype.param.scale, zeroPoint: tensor.dtype.param.zero_point } : null;
                  args.push(value(name, tensorType, initializer, quantization));
              } else if (opr.shape) {
                  const shapeType = new megengine.TensorType('?', new megengine.TensorShape(opr.shape));
                  args.push(value(name, shapeType));
              } else {
                  args.push(value(name));
              }
              return { name: name, type: type, args: args };
          };
          // Indexes operators by output id; multi-output operators get one entry per output.
          const getAllOprAndTensor = (oprs) => {
              const allOprAndTensor = new Map();
              for (const opr of oprs) {
                  if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.outputs.length > 1) {
                      if (opr.type === 'MultipleDeviceTensorWithFormatHolder' || opr.type === 'MultipleDeviceTensorHolder') {
                          opr.type = 'ImmutableTensor';
                      }
                      for (let id = 0; id < opr.outputs.length; id++) {
                          const keyId = opr.outputs[id];
                          const name = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].name : String(keyId);
                          const type = opr.type;
                          const tensors = opr.tensors.length ? [opr.tensors[id]] : [];
                          const onlyShape = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].shape : [];
                          allOprAndTensor.set(keyId, { name: name, type: type, tensors: tensors, shape: onlyShape, inputs: opr.inputs, outputs: opr.outputs });
                          const _opr = allOprAndTensor.get(keyId);
                          _opr.extraInfo = getExtraInfo(_opr);
                      }
                  } else {
                      const [keyId] = opr.outputs;
                      opr.name = obj.middle_tensors[keyId] ? obj.middle_tensors[keyId].name : String(keyId);
                      if (obj.middle_tensors[keyId] && obj.middle_tensors[keyId].shape) {
                          opr.shape = obj.middle_tensors[keyId].shape;
                      }
                      allOprAndTensor.set(keyId, opr);
                      const _opr = allOprAndTensor.get(keyId);
                      _opr.extraInfo = getExtraInfo(_opr);
                  }
              }
              return allOprAndTensor;
          };
          const allOprAndTensor = getAllOprAndTensor(obj.oprs);
          for (const op of Array.from(allOprAndTensor.values())) {
              if (op.type === 'Host2DeviceCopy') {
                  const argument = new megengine.Argument('input', op.extraInfo.args);
                  this.inputs.push(argument);
              } else if (op.type !== 'ImmutableTensor') {
                  this.nodes.push(new megengine.Node(metadata, op, allOprAndTensor));
              }
          }
          for (let i = 0; i < obj.output_vars_idx.length; i++) {
              const id = obj.output_vars_idx[i].compact_id;
              const out_type = `output${i === 0 ? '' : i}`;
              const argument = new megengine.Argument(out_type, allOprAndTensor.get(id).extraInfo.args);
              this.outputs.push(argument);
          }
      }
  };
  /**
   * Named graph or node argument holding a list of values.
   */
  megengine.Argument = class {

      /**
       * @param name - argument name.
       * @param value - values bound to the argument (callers in this file pass arrays).
       */
      constructor(name, value) {
          Object.assign(this, { name, value });
      }
  };
  /**
   * A named value flowing between nodes, optionally backed by an initializer tensor.
   */
  megengine.Value = class {

      /**
       * @param name - value identifier; must be a string.
       * @param type - value type; when falsy, the initializer's type is used if present.
       * @param initializer - optional tensor providing the value's data.
       * @param quantization - optional `{ scale, zeroPoint }`; recorded only when at least
       *     one of the two fields is defined and non-zero.
       * @throws {megengine.Error} when `name` is not a string.
       */
      constructor(name, type, initializer, quantization) {
          if (typeof name !== 'string') {
              throw new megengine.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          if (type) {
              this.type = type;
          } else if (initializer && initializer.type) {
              this.type = initializer.type;
          } else {
              this.type = null;
          }
          this.initializer = initializer;
          const meaningful = (field) => field !== undefined && field !== 0;
          if (quantization && (meaningful(quantization.scale) || meaningful(quantization.zeroPoint))) {
              this.quantization = {
                  type: 'linear',
                  scale: [ quantization.scale ],
                  offset: [ quantization.zeroPoint ]
              };
          }
      }
  };
  /**
   * A single operator in the graph.
   */
  megengine.Node = class {

      /**
       * @param metadata - operator metadata exposing `type(name)` and `attribute(type, name)`.
       * @param item - operator description; `item.type` selects the schema. When both
       *     `item.inputs` and `item.outputs` are present, arguments are resolved through
       *     `allOprAndTensor` and `item.param` entries become attributes.
       * @param allOprAndTensor - map from tensor id to operator record (each carrying
       *     `extraInfo.args`); only read when `item.inputs` and `item.outputs` are set.
       */
      constructor(metadata, item, allOprAndTensor) {
          const versionSuffix = /V(\d+)$/;
          this.name = '';
          // Work on a copy so the adjustments below do not touch the metadata entry.
          this.type = Object.assign({}, metadata.type(item.type));
          let typeName = this.type.name.replace(versionSuffix, '');
          if (typeName.length > 4 && typeName.startsWith('__') && typeName.endsWith('__')) {
              typeName = typeName.substring(2, typeName.length - 2);
          }
          this.type.name = typeName;
          if (!this.type.category) {
              // Fall back to the category registered for the unversioned type name.
              this.type.category = metadata.type(item.type.replace(versionSuffix, '')).category;
          }
          this.inputs = [];
          this.outputs = [];
          this.chain = [];
          this.attributes = [];
          if (!item.inputs || !item.outputs) {
              return;
          }
          // Pairs each tensor id with its schema name (or `${fallback}${index}` once the
          // schema list is exhausted) and appends the resulting arguments to `target`.
          const bind = (ids, schemas, fallback, target) => {
              const known = schemas ? [ ...schemas ] : [];
              for (let index = 0; index < ids.length; index++) {
                  const opr = allOprAndTensor.get(ids[index]);
                  const schema = index < known.length ? known[index] : { name: (`${fallback}${index}`) };
                  target.push(new megengine.Argument(schema.name, opr.extraInfo.args));
              }
          };
          bind(item.inputs, this.type && this.type.inputs, 'input', this.inputs);
          bind(item.outputs, this.type && this.type.outputs, 'output', this.outputs);
          if (item.param) {
              const paramType = item.param.constructor.name;
              for (const [key, entry] of Object.entries(item.param)) {
                  if (entry === null) {
                      continue;
                  }
                  this.attributes.push(new megengine.Attribute(metadata.attribute(paramType, key), key, entry));
              }
          }
      }
  };
  /**
   * Node attribute; enum-typed values are translated to their symbolic names when a
   * schema has been loaded into `megengine.schema`.
   */
  megengine.Attribute = class {

      /**
       * @param metadata - optional attribute metadata; its `type` becomes the attribute type.
       * @param name - attribute name.
       * @param value - attribute value; typed-array views are copied into plain arrays.
       */
      constructor(metadata, name, value) {
          this.type = metadata ? metadata.type : null;
          this.name = name;
          this.value = ArrayBuffer.isView(value) ? Array.from(value) : value;
          if (name === 'training') {
              this.visible = false;
              this.type = 'boolean';
          }
          const schema = megengine.schema;
          if (!schema) {
              return;
          }
          // The first namespace that defines the type wins: param, then root, then v2.
          let scope = null;
          if (schema.param[this.type]) {
              scope = schema.param;
          } else if (schema[this.type]) {
              scope = schema;
          } else if (schema.v2[this.type]) {
              scope = schema.v2;
          }
          if (scope) {
              this.value = megengine.Utility.enum(scope, this.type, this.value);
          }
      }
  };
  /**
   * Constant tensor (initializer) with its type and raw values.
   */
  megengine.Tensor = class {

      /**
       * @param name - tensor name; falsy names are stored as an empty string.
       * @param type - tensor type.
       * @param data - tensor data, stored as `values`.
       */
      constructor(name, type, data) {
          Object.assign(this, {
              category: 'Tensor',
              name: name || '',
              type,
              values: data
          });
      }
  };
  /**
   * Tensor type: an element data type plus a shape.
   */
  megengine.TensorType = class {

      /**
       * @param dataType - data type as a `DTypeEnum` value or a name; string names are
       *     lower-cased and known aliases are mapped to a canonical name.
       * @param shape - shape object, stored as given.
       */
      constructor(dataType, shape) {
          let resolved = megengine.Utility.enum(megengine.schema, 'DTypeEnum', dataType);
          if (typeof resolved === 'string') {
              resolved = resolved.toLowerCase();
          }
          // Alias table is created once and cached on the class.
          if (!megengine.TensorType._dataTypes) {
              megengine.TensorType._dataTypes = new Map([
                  [ 'bool', 'boolean' ],
                  [ 'byte', 'uint8' ], [ 'quantizeds4asymm', 'uint8' ], [ 'quantizeds8asymm', 'uint8' ], [ 'uintb4', 'uint8' ],
                  [ 'quantizeds1', 'int8' ], [ 'quantizeds4', 'int8' ], [ 'quantizeds8', 'int8' ], [ 'intb1', 'int8' ], [ 'intb2', 'int8' ], [ 'intb4', 'int8' ], [ 'qint8', 'int8' ],
                  [ 'quantizeds16', 'int16' ],
                  [ 'quantizeds32', 'int32' ]
              ]);
          }
          const alias = megengine.TensorType._dataTypes.get(resolved);
          this.dataType = alias || resolved;
          this.shape = shape;
      }

      /**
       * Compares data type and shape with another type.
       * @returns a truthy value when both match, a falsy value otherwise.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          if (this.dataType !== obj.dataType) {
              return false;
          }
          if (!this.shape) {
              return this.shape;
          }
          return this.shape.equals(obj.shape);
      }

      /** @returns {string} the data type immediately followed by the shape text. */
      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  /**
   * Tensor shape stored as an array of dimensions.
   */
  megengine.TensorShape = class {

      /**
       * @param dimensions - iterable or array-like of dimensions; a falsy value yields an empty shape.
       */
      constructor(dimensions) {
          this.dimensions = dimensions ? Array.from(dimensions) : [];
      }

      /**
       * Element-wise comparison with another shape.
       * @returns a truthy value when both hold arrays of equal length and content.
       */
      equals(obj) {
          if (!obj) {
              return obj;
          }
          const theirs = obj.dimensions;
          const mine = this.dimensions;
          if (!Array.isArray(theirs) || !Array.isArray(mine)) {
              return false;
          }
          if (mine.length !== theirs.length) {
              return false;
          }
          return theirs.every((value, index) => mine[index] === value);
      }

      /** @returns {string} '[d0,d1,...]', or an empty string when there are no dimensions. */
      toString() {
          if (!this.dimensions || this.dimensions.length === 0) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => dimension.toString());
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Shared helpers.
   */
  megengine.Utility = class {

      /**
       * Translates an enum value into its symbolic name.
       *
       * @param schema - namespace object that may define the enum `name`.
       * @param name - enum name looked up on `schema`.
       * @param value - raw enum value.
       * @returns the key of `schema[name]` whose value equals `value`, or `value`
       *     unchanged when the enum or the value is unknown.
       */
      static enum(schema, name, value) {
          const type = name && schema ? schema[name] : undefined;
          if (type) {
              megengine.Utility._enums = megengine.Utility._enums || new Map();
              // Cache per enum object rather than per name: the same name may exist in
              // several schema namespaces and must not share one reverse table.
              if (!megengine.Utility._enums.has(type)) {
                  const entries = new Map(Object.entries(type).map(([key, entry]) => [ entry, key ]));
                  megengine.Utility._enums.set(type, entries);
              }
              const map = megengine.Utility._enums.get(type);
              if (map.has(value)) {
                  return map.get(value);
              }
          }
          return value;
      }
  };
  /**
   * Error type thrown by the MegEngine loader; `name` carries a fixed description.
   */
  megengine.Error = class extends Error {

      /**
       * @param message - human readable description of the failure.
       */
      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading MegEngine model.' });
      }
  };

  // Public entry point of the module.
  const { ModelFactory } = megengine;
  export { ModelFactory };