megengine.js 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. // Experimental
  2. var megengine = megengine || {};
  3. var base = base || require('./base');
  /**
   * Entry point used by the host application to detect and open pickled
   * MegEngine TracedModule files.
   */
  megengine.ModelFactory = class {

      /**
       * Checks whether the context holds a pickled MegEngine TracedModule.
       * @param {object} context - host-provided model context; `context.open('pkl')`
       *   yields the unpickled root object (NOTE(review): presumably empty/null
       *   when the file is not a pickle — confirm with the host implementation).
       * @returns {string} 'megengine.pickle' on a match, '' otherwise.
       */
      match(context) {
          const obj = context.open('pkl');
          // Guard the dereference so an empty result is rejected with ''
          // instead of raising a TypeError on `obj.__class__`.
          if (obj && obj.__class__ &&
              obj.__class__.__module__ === 'megengine.traced_module.traced_module' &&
              obj.__class__.__name__ === 'TracedModule') {
              return 'megengine.pickle';
          }
          return '';
      }

      /**
       * Loads 'megengine-metadata.json' and builds the model from the pickle.
       * @param {object} context - host-provided model context.
       * @returns {Promise<megengine.Model>} resolves to the constructed model.
       */
      open(context) {
          return context.metadata('megengine-metadata.json').then((metadata) => {
              const obj = context.open('pkl');
              return new megengine.Model(metadata, obj);
          });
      }
  };
  /**
   * Model wrapper: exposes a format string and a single graph built from the
   * root TracedModule.
   */
  megengine.Model = class {

      /**
       * @param {object} metadata - operator metadata returned by `context.metadata(...)`.
       * @param {object} obj - unpickled root TracedModule.
       */
      constructor(metadata, obj) {
          let format = 'MegEngine';
          const dumpInfo = obj.dump_info;
          // Append the dump version only when the pickle records one.
          if (dumpInfo && dumpInfo.version) {
              format += ' v' + dumpInfo.version;
          }
          this._format = format;
          this._graphs = [ new megengine.Graph(metadata, obj) ];
      }

      /** @returns {string} e.g. 'MegEngine' or 'MegEngine v<version>'. */
      get format() {
          const format = this._format;
          return format;
      }

      /** @returns {megengine.Graph[]} always a one-element array. */
      get graphs() {
          const graphs = this._graphs;
          return graphs;
      }
  };
  /**
   * Flattened view of a TracedModule. Calls into nested TracedModules are
   * inlined recursively, so `nodes` only holds leaf operations (module calls,
   * method calls, function calls and `Apply` op invocations).
   */
  megengine.Graph = class {

      /**
       * @param {object} metadata - operator metadata (provides `type(name)`).
       * @param {object} obj - unpickled root TracedModule; the first entry of
       *   its `argdef_graph_map` is used as the root graph.
       */
      constructor(metadata, obj) {
          this._nodes = [];
          this._inputs = [];
          this._outputs = [];

          // True for pickled megengine.tensor.Tensor / Parameter objects.
          const isTensor = (value) => {
              return value && value.__class__ && value.__class__.__module__ === 'megengine.tensor' &&
                  (value.__class__.__name__ === 'Tensor' || value.__class__.__name__ === 'Parameter');
          };

          // Heuristic used to skip sub-module objects stored in a module's state.
          const isModule = (value) => {
              return value && (value.state || value._forward_pre_hooks);
          };

          // Builds a TensorType from a pickled dtype class (null/undefined -> null) and a shape.
          const getTensorType = (dtype, shape) => {
              const dataType = dtype !== null && dtype !== undefined ? dtype.__name__ : null;
              return new megengine.TensorType(dataType, new megengine.TensorShape(shape));
          };

          // Joins a scope prefix and a local name; the root scope uses prefix ''.
          const getfullname = (prefix, name) => {
              return prefix === '' ? name : prefix + '_' + name;
          };

          // Follows the alias chain in `context` (internal name -> caller name)
          // to its last link. The visited set stops on any alias cycle, not
          // only a direct self-reference, so a malformed chain cannot hang.
          const getname = (context, name) => {
              const visited = new Set();
              let current = name;
              while (context.get(current) !== undefined && !visited.has(current)) {
                  visited.add(current);
                  current = context.get(current);
              }
              return current;
          };

          // Resolves a GetAttr expression (following any chain of GetAttr
          // inputs and concatenating their dotted names) against `tmodule`.
          const parse_getattr = (tmodule, getattr_expr) => {
              let current = getattr_expr;
              let path = current.name.split('.');
              while (current.inputs[0].expr.__class__.__name__ === 'GetAttr') {
                  current = current.inputs[0].expr;
                  path = current.name.split('.').concat(path);
              }
              let target = tmodule;
              for (const part of path) {
                  target = target[part];
              }
              return target;
          };

          // Describes a module's type as 'module.Class'.
          const getModuleType = (value) => {
              if (value.module !== undefined) {
                  return value.module[0] + '.' + value.module[1];
              }
              return value.__class__.__module__ + '.' + value.__class__.__name__;
          };

          // Describes a CallFunction's target as 'module.name'.
          const getFunctionType = (expr) => {
              if (expr.func.__module__ !== undefined) {
                  return expr.func.__module__ + '.' + expr.func.__name__;
              }
              return expr.func[0] + '.' + expr.func[1];
          };

          // Renders a pickled TreeDef/LeafDef argument description as a short
          // Python-like string. Non-constant leaves render as the name of their
          // first recorded type (presumably 'Tensor' for tensor arguments,
          // which parseargs renames to 'inpN').
          const formatTreeDef = (treedef) => {
              const kind = treedef.__class__.__name__;
              if (kind !== 'TreeDef' && kind !== 'LeafDef') {
                  throw new megengine.Error('formatTreeDef gets invalid argument');
              }
              if (kind === 'TreeDef') {
                  const type = typeof treedef.type !== 'string' ? treedef.type.__name__ : treedef.type.split('.').slice(-1)[0];
                  const list = treedef.children_defs.map((child) => formatTreeDef(child));
                  switch (type) {
                      case 'tuple':
                          return '(' + list.join(',') + ')';
                      case 'slice':
                          return list.join(':');
                      case 'list':
                          return '[' + list.join(',') + ']';
                      case 'dict': {
                          // Keys live in the TreeDef's own aux_data, parallel to
                          // children_defs (not on `this`, which is the Graph here).
                          const entries = list.map((value, i) => treedef.aux_data[i] + ':' + value);
                          return '{' + entries.join(',') + '}';
                      }
                      default:
                          return type + '(' + list.join(',') + ')';
                  }
              }
              if (treedef.const_val !== null) {
                  return treedef.const_val;
              }
              if (treedef.type[0].__module__ !== undefined) {
                  return treedef.type[0].__name__;
              }
              return 'None';
          };

          // Converts a call's (args, kwargs) TreeDefs into an attribute map.
          // Positional names come from `meta.schema.attributes` (plus
          // `schema.varargs` + index for extra positionals) when the schema
          // fits, else 'argN'. Every 'Tensor' placeholder in a rendered value
          // is renamed 'inpN', numbered across all args and then kwargs.
          const parseargs = (args, kwargs, meta) => {
              const state = {};
              const schema = meta !== undefined && meta !== null ? meta.schema : undefined;
              let inp_idx = 0;
              const replaceTensors = (text) => {
                  let value = text;
                  while (typeof value === 'string' && value.indexOf('Tensor') !== -1) {
                      value = value.replace('Tensor', 'inp' + inp_idx);
                      inp_idx += 1;
                  }
                  return value;
              };
              const names = schema ? schema.attributes : [];
              // A missing `varargs` is treated like an explicit null so extra
              // positionals never get named 'undefined0'.
              const hasVarargs = schema ? schema.varargs !== null && schema.varargs !== undefined : false;
              const generic = !schema || (names.length !== args.children_defs.length && !hasVarargs);
              args.children_defs.forEach((arg, arg_idx) => {
                  let attr_name;
                  if (generic) {
                      attr_name = 'arg' + arg_idx;
                  } else if (arg_idx < names.length) {
                      attr_name = names[arg_idx];
                  } else {
                      attr_name = schema.varargs + (arg_idx - names.length);
                  }
                  state[attr_name] = replaceTensors(formatTreeDef(arg));
              });
              kwargs.children_defs.forEach((child, i) => {
                  state[kwargs.aux_data[i]] = replaceTensors(formatTreeDef(child));
              });
              return state;
          };

          // Creates a Node for one expression: non-ModuleNode inputs become
          // 'inpN' inputs, outputs become 'outN' outputs, and `state` entries
          // become attributes (or initializer inputs when they hold tensors).
          // A quantized output adds scale / zero_point / quant_dtype_meta.
          const getOpNode = (metadata, item, expr, state) => {
              const op = new megengine.Node(metadata, item);
              let inp_idx = 0;
              for (const input of expr.inputs) {
                  // ModuleNode inputs are module references, not data edges.
                  if (input.__class__.__name__ !== 'ModuleNode') {
                      const initializer = input.initializer !== undefined ? input.initializer : null;
                      op._inputs.push(new megengine.Parameter('inp' + inp_idx, true, [
                          new megengine.Argument(input._fullname, getTensorType(input._dtype, input._shape), initializer)
                      ]));
                      inp_idx += 1;
                  }
              }
              let qparams = null;
              expr.outputs.forEach((output, out_idx) => {
                  if (output._qparams !== null && output._qparams !== undefined) {
                      qparams = output._qparams[1];
                  }
                  op._outputs.push(new megengine.Parameter('out' + out_idx, true, [
                      new megengine.Argument(output._fullname, getTensorType(output._dtype, output._shape), null)
                  ]));
              });
              let attrs = state;
              if (qparams !== null && qparams !== undefined) {
                  // `state` may be null or undefined (e.g. Apply without opdef_state).
                  attrs = attrs !== null && attrs !== undefined ? attrs : {};
                  attrs['scale'] = qparams.scale;
                  attrs['zero_point'] = qparams.zero_point;
                  attrs['quant_dtype_meta'] = qparams.dtype_meta;
              }
              if (attrs !== null && attrs !== undefined) {
                  for (const key in attrs) {
                      const value = attrs[key];
                      if (key.startsWith('_') || isModule(value)) {
                          continue;
                      }
                      if (isTensor(value)) {
                          op._inputs.push(new megengine.Parameter(key, true, [
                              new megengine.Argument('', getTensorType(value.dtype, value.data.shape), new megengine.Tensor(key, value))
                          ]));
                      } else {
                          op._attributes.push(new megengine.Attribute(null, key, value !== null ? value : 'None'));
                      }
                  }
              }
              return op;
          };

          // Walks `igraph` (a graph of `tmodule`) and appends nodes to this
          // graph. `context` maps names inside inlined sub-graphs to the
          // caller's names so edges connect across TracedModule boundaries;
          // `name_prefix` scopes names of the current sub-graph. Only the root
          // graph (`isroot`) contributes graph-level inputs and outputs.
          const loadgraph = (tmodule, igraph, context, name_prefix, metadata, isroot) => {
              const expressions = igraph._exprs;
              if (isroot) {
                  for (const node of igraph._inputs) {
                      if (node.__class__.__name__ !== 'ModuleNode') {
                          this._inputs.push(new megengine.Parameter(node._name, true, [
                              new megengine.Argument(node._name, getTensorType(node._dtype, node._shape), null)
                          ]));
                      }
                  }
                  for (const node of igraph._outputs) {
                      this._outputs.push(new megengine.Parameter(node._name, true, [
                          new megengine.Argument(node._name, getTensorType(node._dtype, node._shape), null)
                      ]));
                  }
              }
              for (const expr of expressions) {
                  for (const i of expr.inputs) {
                      i._fullname = getname(context, getfullname(name_prefix, i._name));
                  }
                  for (const o of expr.outputs) {
                      o._fullname = getname(context, getfullname(name_prefix, o._name));
                  }
                  switch (expr.__class__.__name__) {
                      case 'GetAttr': {
                          // Tensor attributes of the module become initializers.
                          if (expr.outputs[0].__class__.__name__ === 'TensorNode') {
                              const tensor = parse_getattr(tmodule, expr);
                              expr.outputs[0].initializer = new megengine.Tensor(expr.name, tensor);
                          }
                          break;
                      }
                      case 'Constant': {
                          if (expr.outputs[0].__class__.__name__ === 'TensorNode') {
                              expr.outputs[0].initializer = new megengine.Tensor('', expr.value);
                          }
                          break;
                      }
                      case 'CallMethod': {
                          if (expr.method === '__call__') {
                              const called_module = parse_getattr(tmodule, expr.inputs[0].expr);
                              if (called_module.__class__.__name__ === 'TracedModule') {
                                  // Inline the sub-graph: alias its boundary names to
                                  // the caller's names, then recurse under a new prefix.
                                  const prefix = getfullname(name_prefix, expr.inputs[0]._name);
                                  const internal_graph = called_module.argdef_graph_map[expr.arg_def.toString()];
                                  for (let i = 0; i < expr.inputs.length; i++) {
                                      const actual_name = getfullname(name_prefix, expr.inputs[i]._name);
                                      context.set(getfullname(prefix, internal_graph._inputs[i]._name), actual_name);
                                  }
                                  for (let i = 0; i < expr.outputs.length; i++) {
                                      const actual_name = getfullname(name_prefix, expr.outputs[i]._name);
                                      context.set(getfullname(prefix, internal_graph._outputs[i]._name), actual_name);
                                  }
                                  loadgraph(called_module, internal_graph, context, prefix, metadata, false);
                                  break;
                              }
                              const item = { 'name': '', 'type': getModuleType(called_module) };
                              const state = called_module.state !== undefined ? called_module.state : called_module;
                              this._nodes.push(getOpNode(metadata, item, expr, state));
                          } else {
                              const item = { 'name': '', 'type': expr.method };
                              const args = expr.arg_def.children_defs[0];
                              const kwargs = expr.arg_def.children_defs[1];
                              const state = parseargs(args, kwargs, metadata.type(expr.method));
                              this._nodes.push(getOpNode(metadata, item, expr, state));
                          }
                          break;
                      }
                      case 'CallFunction': {
                          const func = getFunctionType(expr);
                          const item = { 'name': '', 'type': func };
                          const args = expr.arg_def.children_defs[0];
                          const kwargs = expr.arg_def.children_defs[1];
                          // Same metadata accessor as the CallMethod path and megengine.Node.
                          const state = parseargs(args, kwargs, metadata.type(func));
                          this._nodes.push(getOpNode(metadata, item, expr, state));
                          break;
                      }
                      case 'Apply': {
                          const opdef = expr.opdef_state ? expr.opdef_state.opdef_type : expr.opdef.type;
                          const item = { 'name': '', 'type': opdef.__module__ + '.' + opdef.__name__ };
                          this._nodes.push(getOpNode(metadata, item, expr, expr.opdef_state));
                          break;
                      }
                      default: {
                          // 'Input' and unrecognized expressions add no nodes.
                          break;
                      }
                  }
              }
          };

          const graph = Object.values(obj.argdef_graph_map)[0];
          loadgraph(obj, graph, new Map(), '', metadata, true);
      }

      /** @returns {megengine.Parameter[]} root graph inputs (ModuleNodes excluded). */
      get inputs() {
          return this._inputs;
      }

      /** @returns {megengine.Parameter[]} root graph outputs. */
      get outputs() {
          return this._outputs;
      }

      /** @returns {megengine.Node[]} flattened operation nodes. */
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * Named input/output slot of a node or graph, holding a list of arguments.
   */
  megengine.Parameter = class {

      /**
       * @param {string} name - slot name.
       * @param {boolean} visible - whether the slot is shown.
       * @param {megengine.Argument[]} args - values bound to this slot.
       */
      constructor(name, visible, args) {
          Object.assign(this, {
              _name: name,
              _visible: visible,
              _arguments: args
          });
      }

      /** @returns {string} slot name. */
      get name() {
          const name = this._name;
          return name;
      }

      /** @returns {boolean} visibility flag as given to the constructor. */
      get visible() {
          const visible = this._visible;
          return visible;
      }

      /** @returns {megengine.Argument[]} bound arguments. */
      get arguments() {
          const args = this._arguments;
          return args;
      }
  };
  /**
   * A value flowing along an edge: a name, an optional type and an optional
   * initializer tensor.
   */
  megengine.Argument = class {

      /**
       * @param {string} name - identifier; anything other than a string is rejected.
       * @param {megengine.TensorType} type - declared type (may be null).
       * @param {megengine.Tensor|null} initializer - constant data, if any.
       * @throws {megengine.Error} when `name` is not a string.
       */
      constructor(name, type, initializer) {
          const isValidName = typeof name === 'string';
          if (!isValidName) {
              throw new megengine.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
          }
          this._name = name;
          this._type = type;
          this._initializer = initializer;
      }

      /** @returns {string} argument identifier. */
      get name() {
          return this._name;
      }

      /** @returns {megengine.TensorType} the initializer's type when present, else the declared type. */
      get type() {
          return this._initializer ? this._initializer.type : this._type;
      }

      /** @returns {megengine.Tensor|null|undefined} constant data, if any. */
      get initializer() {
          return this._initializer;
      }
  };
  /**
   * A single operation. Its type descriptor is a shallow copy of the metadata
   * entry for `item.type`; Python dunder names such as '__add__' are shortened
   * to 'add'.
   */
  megengine.Node = class {

      /**
       * @param {object} metadata - operator metadata (provides `type(name)`).
       * @param {{name: string, type: string}} item - node name and type key.
       */
      constructor(metadata, item) {
          this._name = item.name || '';
          const type = Object.assign({}, metadata.type(item.type));
          this._type = type;
          const typeName = type.name;
          const isDunder = typeName.length > 4 && typeName.startsWith('__') && typeName.endsWith('__');
          if (isDunder) {
              type.name = typeName.slice(2, -2);
          }
          this._inputs = [];
          this._outputs = [];
          this._chain = [];
          this._attributes = [];
      }

      /** @returns {string} node name ('' when none was given). */
      get name() {
          return this._name;
      }

      /** @returns {object} copied metadata type descriptor. */
      get type() {
          return this._type;
      }

      /** @returns {megengine.Attribute[]} node attributes. */
      get attributes() {
          return this._attributes;
      }

      /** @returns {megengine.Parameter[]} node inputs. */
      get inputs() {
          return this._inputs;
      }

      /** @returns {megengine.Parameter[]} node outputs. */
      get outputs() {
          return this._outputs;
      }

      // NOTE(review): `_nodes` is never assigned in the visible code, so this
      // appears to always return undefined — confirm whether it is needed.
      get nodes() {
          return this._nodes;
      }
  };
  /**
   * A named attribute value displayed on a node.
   */
  megengine.Attribute = class {

      /**
       * @param {object|null} metadata - unused here; kept for interface compatibility.
       * @param {string} name - attribute name.
       * @param {*} value - attribute value.
       */
      constructor(metadata, name, value) {
          this._name = name;
          this._value = value;
          // The module's `training` flag is marked boolean and hidden by default.
          if (this._name === 'training') {
              this._visible = false;
              this._type = 'boolean';
          }
      }

      /** @returns {string|undefined} 'boolean' for `training`, otherwise undefined. */
      get type() {
          return this._type;
      }

      /** @returns {string} attribute name. */
      get name() {
          return this._name;
      }

      /** @returns {*} attribute value. */
      get value() {
          return this._value;
      }

      /** @returns {boolean} false only when the attribute was explicitly hidden. */
      get visible() {
          // Strict comparison: only an explicit `false` hides the attribute.
          return this._visible !== false;
      }
  };
  /**
   * Tensor initializer backed by pickled MegEngine tensor data.
   */
  megengine.Tensor = class {

      /**
       * @param {string} name - display name ('' when falsy).
       * @param {object} tensor - pickled tensor exposing `dtype.__name__`,
       *   `data.shape` and `data.data` (a Uint8Array or an object with `peek()`).
       */
      constructor(name, tensor) {
          this._name = name || '';
          const dataType = tensor.dtype.__name__;
          const shape = new megengine.TensorShape(tensor.data.shape);
          this._type = new megengine.TensorType(dataType, shape);
          this._data = tensor.data.data;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** @returns {string|null} why the data cannot be decoded, or null. */
      get state() {
          const context = this._context();
          return context.state;
      }

      /** @returns {Array|*|null} fully decoded values, or null when undecodable. */
      get value() {
          const context = this._context();
          if (!context.state) {
              context.limit = Number.MAX_SAFE_INTEGER;
              return this._decode(context, 0);
          }
          return null;
      }

      /** Renders up to ~10000 elements as an indented nested list ('' when undecodable). */
      toString() {
          const context = this._context();
          if (!context.state) {
              context.limit = 10000;
              const decoded = this._decode(context, 0);
              return megengine.Tensor._stringify(decoded, '', ' ');
          }
          return '';
      }

      // Validates type/shape/data and prepares a decoding cursor; on failure
      // only `state` carries the reason.
      _context() {
          const context = { state: null, index: 0, count: 0 };
          const dataType = this._type.dataType;
          if (!dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          const supported = [
              'boolean', 'uint8', 'qint8', 'int8', 'int16', 'int32', 'int64',
              'float16', 'float32', 'float64', 'bfloat16'
          ];
          if (supported.indexOf(dataType) === -1) {
              context.state = "Tensor data type '" + dataType + "' is not supported.";
              return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          try {
              context.data = this._data instanceof Uint8Array ? this._data : this._data.peek();
          }
          catch (err) {
              context.state = err.message;
              return context;
          }
          context.dataType = dataType;
          context.dimensions = this._type.shape.dimensions;
          context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      // Recursively decodes `dimension` and below as little-endian values,
      // emitting '...' once more than `context.limit` scalars were read.
      // A rank-0 tensor is treated as shape [1] and unwrapped to a scalar.
      _decode(context, dimension) {
          const dimensions = context.dimensions.length === 0 ? [1] : context.dimensions;
          const size = dimensions[dimension];
          const isInnermost = dimension === dimensions.length - 1;
          const results = [];
          for (let i = 0; i < size; i++) {
              if (context.count > context.limit) {
                  results.push('...');
                  return results;
              }
              if (!isInnermost) {
                  results.push(this._decode(context, dimension + 1));
                  continue;
              }
              const view = context.dataView;
              const offset = context.index;
              let value;
              let width;
              switch (context.dataType) {
                  case 'boolean':
                      value = view.getUint8(offset) !== 0;
                      width = 1;
                      break;
                  case 'uint8':
                      value = view.getUint8(offset);
                      width = 1;
                      break;
                  case 'qint8':
                  case 'int8':
                      value = view.getInt8(offset);
                      width = 1;
                      break;
                  case 'int16':
                      value = view.getInt16(offset, true);
                      width = 2;
                      break;
                  case 'int32':
                      value = view.getInt32(offset, true);
                      width = 4;
                      break;
                  case 'int64':
                      value = view.getInt64(offset, true);
                      width = 8;
                      break;
                  case 'float16':
                      value = view.getFloat16(offset, true);
                      width = 2;
                      break;
                  case 'float32':
                      value = view.getFloat32(offset, true);
                      width = 4;
                      break;
                  case 'float64':
                      value = view.getFloat64(offset, true);
                      width = 8;
                      break;
                  case 'bfloat16':
                      value = view.getBfloat16(offset, true);
                      width = 2;
                      break;
                  default:
                      throw new megengine.Error("Unsupported tensor data type '" + context.dataType + "'.");
              }
              results.push(value);
              context.index = offset + width;
              context.count++;
          }
          return context.dimensions.length === 0 ? results[0] : results;
      }

      // Formats nested arrays one element per line, indenting each level by `indent`.
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const lines = [ indentation + '[' ];
              if (value.length > 0) {
                  const nested = value.map((item) => megengine.Tensor._stringify(item, indentation + indent, indent));
                  lines.push(nested.join(',\n'));
              }
              lines.push(indentation + ']');
              return lines.join('\n');
          }
          let text;
          if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
              text = value.toString();
          }
          else if (typeof value === 'string') {
              text = value;
          }
          else if (value == Infinity) {
              text = 'Infinity';
          }
          else if (value == -Infinity) {
              text = '-Infinity';
          }
          else if (isNaN(value)) {
              text = 'NaN';
          }
          else {
              text = value.toString();
          }
          return indentation + text;
      }
  };
  /**
   * Element data type plus shape of a tensor.
   */
  megengine.TensorType = class {

      /**
       * @param {string|null} dataType - dtype name, e.g. 'float32'.
       * @param {megengine.TensorShape} shape - tensor shape.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() {
          const dataType = this._dataType;
          return dataType;
      }

      get shape() {
          const shape = this._shape;
          return shape;
      }

      /** @returns {string} data type followed by the shape text, e.g. 'float32[1,3]'. */
      toString() {
          const shapeText = this._shape.toString();
          return this._dataType + shapeText;
      }
  };
  /**
   * Tensor dimensions; a missing list is stored as [].
   */
  megengine.TensorShape = class {

      /** @param {Array|undefined|null} dimensions - per-axis sizes. */
      constructor(dimensions) {
          this._dimensions = dimensions ? dimensions : [];
      }

      get dimensions() {
          const dimensions = this._dimensions;
          return dimensions;
      }

      /** @returns {string} '[d0,d1,...]', or '' for an empty/rank-0 shape. */
      toString() {
          const dims = this._dimensions;
          const hasDims = dims && dims.length > 0;
          if (!hasDims) {
              return '';
          }
          const parts = dims.map((dim) => dim.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  /**
   * Error raised while loading a MegEngine model; its `name` identifies the
   * loader as the source of the failure.
   */
  megengine.Error = class extends Error {

      /** @param {string} message - description of the failure. */
      constructor(message) {
          super(message);
          this.name = 'Error loading MegEngine model.';
      }
  };
  // Publish the factory through CommonJS when a `module` object is available.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
      const target = module.exports;
      target.ModelFactory = megengine.ModelFactory;
  }