paddle.js 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100
  1. var paddle = paddle || {};
  2. var flatbuffers = flatbuffers || require('./flatbuffers');
  3. var protobuf = protobuf || require('./protobuf');
  4. var python = python || require('./python');
  5. var base = base || require('./base');
  6. paddle.ModelFactory = class {
  7. match(context) {
  8. const identifier = context.identifier;
  9. const extension = identifier.split('.').pop().toLowerCase();
  10. if (identifier === '__model__' || extension === '__model__' || extension === 'paddle' || extension === 'pdmodel') {
  11. const tags = context.tags('pb');
  12. if (tags.get(1) === 2) {
  13. return 'paddle.pb';
  14. }
  15. }
  16. if (extension === 'pbtxt' || extension === 'txt') {
  17. const tags = context.tags('pbtxt');
  18. if (tags.has('blocks')) {
  19. return 'paddle.pbtxt';
  20. }
  21. }
  22. const stream = context.stream;
  23. if (stream && stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
  24. return 'paddle.params';
  25. }
  26. if (paddle.Pickle.open(context)) {
  27. return 'paddle.pickle';
  28. }
  29. if (paddle.Entries.open(context)) {
  30. return 'paddle.entries';
  31. }
  32. if (paddle.NaiveBuffer.open(context)) {
  33. return 'paddle.naive';
  34. }
  35. return undefined;
  36. }
  37. open(context, match) {
  38. return context.metadata('paddle-metadata.json').then((metadata) => {
  39. switch (match) {
  40. case 'paddle.naive': {
  41. return context.require('./paddle-schema').then(() => {
  42. paddle.schema = flatbuffers.get('paddlelite').paddle.lite.fbs.proto;
  43. const file = paddle.NaiveBuffer.open(context);
  44. return new paddle.Model(metadata, file.format, file.model, file.weights);
  45. });
  46. }
  47. default: {
  48. return context.require('./paddle-proto').then(() => {
  49. paddle.proto = protobuf.get('paddle').paddle.framework.proto;
  50. const identifier = context.identifier;
  51. const parts = identifier.split('.');
  52. const extension = parts.pop().toLowerCase();
  53. const base = parts.join('.');
  54. const openProgram = (stream, match) => {
  55. const program = {};
  56. switch (match) {
  57. case 'paddle.pbtxt': {
  58. try {
  59. const reader = protobuf.TextReader.open(stream);
  60. program.desc = paddle.proto.ProgramDesc.decodeText(reader);
  61. }
  62. catch (error) {
  63. const message = error && error.message ? error.message : error.toString();
  64. throw new paddle.Error('File text format is not paddle.ProgramDesc (' + message.replace(/\.$/, '') + ').');
  65. }
  66. break;
  67. }
  68. case 'paddle.pb': {
  69. try {
  70. const reader = protobuf.BinaryReader.open(stream);
  71. program.desc = paddle.proto.ProgramDesc.decode(reader);
  72. }
  73. catch (error) {
  74. const message = error && error.message ? error.message : error.toString();
  75. throw new paddle.Error('File format is not paddle.ProgramDesc (' + message.replace(/\.$/, '') + ').');
  76. }
  77. break;
  78. }
  79. default: {
  80. throw new paddle.Error("Unsupported Paddle format '" + match + "'.");
  81. }
  82. }
  83. const formatVersion = (version) => {
  84. if (version && version.version && version.version.toNumber) {
  85. const number = version.version.toNumber();
  86. if (number > 0) {
  87. const list = [ Math.floor(number / 1000000) % 1000, Math.floor(number / 1000) % 1000, number % 1000 ];
  88. if (list.slice(-1).pop() === 0) {
  89. list.pop();
  90. if (list.slice(-1).pop() === 0) {
  91. list.pop();
  92. }
  93. }
  94. return ' v' + list.map((item) => item.toString()).join('.');
  95. }
  96. }
  97. return '';
  98. };
  99. program.format = 'PaddlePaddle' + formatVersion(program.desc.version);
  100. const variables = new Set();
  101. for (const block of program.desc.blocks) {
  102. const blockVars = new Set();
  103. for (const variable of block.vars) {
  104. if (variable.persistable && variable.type &&
  105. variable.type.type != paddle.DataType.FETCH_LIST &&
  106. variable.type.type != paddle.DataType.FEED_MINIBATCH) {
  107. blockVars.add(variable.name);
  108. }
  109. }
  110. for (const op of block.ops) {
  111. for (const input of op.inputs) {
  112. for (const argument of input.arguments) {
  113. if (blockVars.has(argument)) {
  114. variables.add(argument);
  115. }
  116. }
  117. }
  118. }
  119. }
  120. program.vars = Array.from(variables).sort();
  121. return program;
  122. };
  123. const createModel = (metadata, format, desc, tensors) => {
  124. return new paddle.Model(metadata, format, desc, tensors);
  125. };
  126. const loadParams = (stream) => {
  127. const params = [];
  128. while (stream.position < stream.length) {
  129. const tensor = paddle.Utility.openTensorDesc(stream);
  130. params.push(tensor);
  131. }
  132. return params;
  133. };
  134. const mapParams = (params, program) => {
  135. const weights = new Map();
  136. const vars = program.vars.slice();
  137. for (const param of params) {
  138. weights.set(vars.shift(), param);
  139. }
  140. return weights;
  141. };
  142. switch (match) {
  143. case 'paddle.pickle': {
  144. const container = paddle.Pickle.open(context);
  145. return createModel(metadata, container.format, null, container.weights);
  146. }
  147. case 'paddle.entries': {
  148. const container = paddle.Entries.open(context);
  149. return createModel(metadata, container.format, null, container.weights);
  150. }
  151. case 'paddle.params': {
  152. const file = identifier !== 'params' ? base + '.pdmodel' : 'model';
  153. const params = loadParams(context.stream);
  154. return context.request(file, null).then((stream) => {
  155. const program = openProgram(stream, 'paddle.pb');
  156. const weights = mapParams(params, program);
  157. return createModel(metadata, program.format, program.desc, weights);
  158. }).catch(() => {
  159. const weights = new Map(params.map((param, index) => [ index.toString(), param ]));
  160. return createModel(metadata, 'PaddlePaddle Inference Weights', null, weights);
  161. });
  162. }
  163. case 'paddle.pb':
  164. case 'paddle.pbtxt': {
  165. const loadEntries = (context, program) => {
  166. const promises = program.vars.map((name) => context.request(name, null).then((stream) => stream).catch(() => null));
  167. return Promise.all(promises).then((streams) => {
  168. const params = streams.map((stream) => stream ? paddle.Utility.openTensorDesc(stream) : null);
  169. const weights = mapParams(params, program);
  170. return createModel(metadata, program.format, program.desc, weights);
  171. });
  172. };
  173. const openNumPyArrayPickle = (stream) => {
  174. const execution = new python.Execution(null);
  175. const unpickler = python.Unpickler.open(stream, execution);
  176. const obj = unpickler.load();
  177. const container = new paddle.Pickle(obj);
  178. return container.weights || new Map();
  179. };
  180. const program = openProgram(context.stream, match);
  181. if (extension === 'pdmodel') {
  182. return context.request(base + '.pdiparams', null).then((stream) => {
  183. const params = loadParams(stream);
  184. const weights = mapParams(params, program);
  185. return createModel(metadata, program.format, program.desc, weights);
  186. }).catch((/* err */) => {
  187. return context.request(base + '.pdparams', null).then((stream) => {
  188. const weights = openNumPyArrayPickle(stream);
  189. return context.request(base + '.pdopt', null).then((stream) => {
  190. for (const entry of openNumPyArrayPickle(stream)) {
  191. if (!weights.has(entry[0])) {
  192. weights.set(entry[0], entry[1]);
  193. }
  194. }
  195. return createModel(metadata, program.format, program.desc, weights);
  196. }).catch((/* err */) => {
  197. return createModel(metadata, program.format, program.desc, weights);
  198. });
  199. }).catch((/* err */) => {
  200. return context.request(base + '.pdopt', null).then((stream) => {
  201. const weights = openNumPyArrayPickle(stream);
  202. return createModel(metadata, program.format, program.desc, weights);
  203. }).catch((/* err */) => {
  204. return loadEntries(context, program);
  205. });
  206. });
  207. });
  208. }
  209. if (identifier === 'model') {
  210. return context.request('params', null).then((stream) => {
  211. const params = loadParams(stream);
  212. const weights = mapParams(params, program);
  213. return createModel(metadata, program.format, program.desc, weights);
  214. }).catch((/* err */) => {
  215. return loadEntries(context, program);
  216. });
  217. }
  218. return loadEntries(context, program);
  219. }
  220. default: {
  221. throw new paddle.Error("Unsupported PaddlePaddle format '" + match + "'.");
  222. }
  223. }
  224. });
  225. }
  226. }
  227. });
  228. }
  229. };
  230. paddle.Model = class {
  231. constructor(metadata, format, programDesc, tensors) {
  232. this._format = format;
  233. this._graphs = programDesc ?
  234. programDesc.blocks.map((block) => new paddle.Graph(metadata, block, tensors)) :
  235. [ new paddle.Graph(metadata, null, tensors) ];
  236. }
  237. get format() {
  238. return this._format;
  239. }
  240. get graphs() {
  241. return this._graphs;
  242. }
  243. };
  244. paddle.Graph = class {
  245. constructor(metadata, block, tensors) {
  246. this._nodes = [];
  247. this._inputs = [];
  248. this._outputs = [];
  249. if (block) {
  250. this._name = block.idx.toString();
  251. const args = new Map();
  252. for (const variable of block.vars) {
  253. const type = variable.type && variable.type.type && variable.type.lod_tensor && variable.type.lod_tensor.tensor ? paddle.Utility.createTensorType(variable.type.lod_tensor.tensor.data_type, variable.type.lod_tensor.tensor.dims) : null;
  254. const tensor = variable.persistable && variable.type && variable.type.type != paddle.DataType.FETCH_LIST && variable.type.type != paddle.DataType.FEED_MINIBATCH ? (tensors.get(variable.name) || new paddle.Tensor(type)) : null;
  255. args.set(variable.name, new paddle.Argument(variable.name, type, tensor));
  256. }
  257. const scope = {};
  258. for (let i = 0; i < block.ops.length; i++) {
  259. for (const input of block.ops[i].inputs) {
  260. input.arguments = input.arguments.map((argument) => scope[argument] ? scope[argument] : argument);
  261. }
  262. for (const output of block.ops[i].outputs) {
  263. output.arguments = output.arguments.map((argument) => {
  264. if (scope[argument]) {
  265. const next = argument + '\n' + i.toString(); // custom argument id
  266. scope[argument] = next;
  267. return next;
  268. }
  269. scope[argument] = argument;
  270. return argument;
  271. });
  272. }
  273. }
  274. for (const op of block.ops) {
  275. for (const input of op.inputs) {
  276. for (const argument of input.arguments) {
  277. const name = argument;
  278. if (!args.has(name)) {
  279. args.set(name, new paddle.Argument(name, null, null));
  280. }
  281. }
  282. }
  283. for (const output of op.outputs) {
  284. for (const argument of output.arguments) {
  285. const name = argument;
  286. if (!args.has(name)) {
  287. args.set(name, new paddle.Argument(name, null, null));
  288. }
  289. }
  290. }
  291. }
  292. let lastNode = null;
  293. let lastOutput = null;
  294. for (const op of block.ops) {
  295. if (op.type == 'feed') {
  296. const inputName = op.attrs.filter((attr) => attr.name == 'col')[0].i.toString();
  297. this._inputs.push(new paddle.Parameter(inputName, op.outputs[0].arguments.map((id) => args.get(id))));
  298. }
  299. else if (op.type == 'fetch') {
  300. const outputName = op.attrs.filter((attr) => attr.name == 'col')[0].i.toString();
  301. this._outputs.push(new paddle.Parameter(outputName, op.inputs[0].arguments.map((id) => args.get(id))));
  302. }
  303. else {
  304. const node = new paddle.Node(metadata, op, args);
  305. if (op.inputs.length == 1 && op.inputs[0].arguments.length == 1 &&
  306. op.outputs.length >= 1 && op.outputs[0].arguments.length == 1 &&
  307. op.inputs[0].arguments[0].split('\n').shift() == op.outputs[0].arguments[0].split('\n').shift() &&
  308. lastNode &&
  309. lastOutput == op.inputs[0].arguments[0].split('\n').shift()) {
  310. lastNode.chain.push(node);
  311. }
  312. else {
  313. this._nodes.push(node);
  314. lastNode = null;
  315. lastOutput = null;
  316. if (op.outputs.length == 1 && op.outputs[0].arguments.length == 1) {
  317. lastNode = node;
  318. lastOutput = op.outputs[0].arguments[0].split('\n').shift();
  319. }
  320. }
  321. }
  322. }
  323. }
  324. else {
  325. const args = new Map();
  326. const ops = new Map();
  327. for (const pair of tensors) {
  328. const name = pair[0];
  329. const tensor = pair[1];
  330. args.set(name, new paddle.Argument(name, tensor.type, tensor));
  331. const separator = name.indexOf('.') !== -1 ? '.' : '_';
  332. const regex = /(.*)_((w_attr|scale|weights|offset|b|w|b_attr)_(moment|beta|velocity|mean_square|mean_grad).*)/;
  333. const parts = separator === '.' ? name.split(separator) : (regex.test(name) ? regex.exec(name).slice(1, 3) : [ '', name ]);
  334. const parameter_name = parts.pop();
  335. const op_name = parts.join(separator);
  336. if (!ops.has(op_name)) {
  337. ops.set(op_name, { name: op_name, type: 'Weights', inputs: [] });
  338. }
  339. const op = ops.get(op_name);
  340. op.inputs.push({ parameter: parameter_name, arguments: [ name ] });
  341. }
  342. for (const pair of ops) {
  343. const op = pair[1];
  344. this._nodes.push(new paddle.Node(metadata, op, args));
  345. }
  346. }
  347. }
  348. get name() {
  349. return this._name;
  350. }
  351. get inputs() {
  352. return this._inputs;
  353. }
  354. get outputs() {
  355. return this._outputs;
  356. }
  357. get nodes() {
  358. return this._nodes;
  359. }
  360. };
  361. paddle.Parameter = class {
  362. constructor(name, args) {
  363. this._name = name;
  364. this._arguments = args;
  365. }
  366. get name() {
  367. return this._name;
  368. }
  369. get visible() {
  370. return true;
  371. }
  372. get arguments() {
  373. return this._arguments;
  374. }
  375. };
  376. paddle.Argument = class {
  377. constructor(name, type, initializer) {
  378. if (typeof name !== 'string') {
  379. throw new paddle.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  380. }
  381. this._name = name;
  382. this._type = type || null;
  383. this._initializer = initializer || null;
  384. }
  385. get name() {
  386. return this._name;
  387. }
  388. get type() {
  389. if (this._type) {
  390. return this._type;
  391. }
  392. if (this._initializer) {
  393. return this._initializer.type;
  394. }
  395. return null;
  396. }
  397. get initializer() {
  398. return this._initializer;
  399. }
  400. };
  401. paddle.Node = class {
  402. constructor(metadata, op, args) {
  403. const type = op.type;
  404. this._type = metadata.type(type) || { name: type };
  405. this._name = op.name || '';
  406. this._attributes = [];
  407. this._inputs = [];
  408. this._outputs = [];
  409. this._chain = [];
  410. if (op.attrs) {
  411. this._attributes = op.attrs.map((attr) => new paddle.Attribute(metadata.attribute(type, this._name), attr));
  412. }
  413. if (op.inputs) {
  414. for (const input of op.inputs) {
  415. if (input.arguments.length > 0) {
  416. this._inputs.push(new paddle.Parameter(input.parameter, input.arguments.map((name) => args.get(name))));
  417. }
  418. }
  419. }
  420. if (op.outputs) {
  421. for (const output of op.outputs) {
  422. if (output.arguments.length > 0) {
  423. this._outputs.push(new paddle.Parameter(output.parameter, output.arguments.map((name) => args.get(name))));
  424. }
  425. }
  426. }
  427. this._update(this._inputs, 'X');
  428. this._update(this._inputs, 'Input');
  429. this._update(this._outputs, 'Y');
  430. this._update(this._outputs, 'Out');
  431. }
  432. get type() {
  433. return this._type;
  434. }
  435. get name() {
  436. return this._name;
  437. }
  438. get attributes() {
  439. return this._attributes;
  440. }
  441. get inputs() {
  442. return this._inputs;
  443. }
  444. get outputs() {
  445. return this._outputs;
  446. }
  447. get chain() {
  448. return this._chain;
  449. }
  450. _update(list, name) {
  451. let item = null;
  452. for (let i = 0; i < list.length; i++) {
  453. if (list[i].name == name) {
  454. item = list[i];
  455. list.splice(i, 1);
  456. break;
  457. }
  458. }
  459. if (item) {
  460. list.splice(0, 0, item);
  461. }
  462. }
  463. };
  464. paddle.Attribute = class {
  465. constructor(schema, attr) {
  466. this._name = attr.name;
  467. this._value = '?';
  468. switch (attr.type) {
  469. case paddle.AttributeType.STRING:
  470. this._type = 'string';
  471. this._value = attr.s;
  472. break;
  473. case paddle.AttributeType.STRINGS:
  474. this._type = 'string[]';
  475. this._value = Array.from(attr.strings);
  476. break;
  477. case paddle.AttributeType.BOOLEAN:
  478. this._type = 'boolean';
  479. this._value = attr.b;
  480. break;
  481. case paddle.AttributeType.BOOLEANS:
  482. this._type = 'boolean[]';
  483. this._value = Array.from(attr.bools);
  484. break;
  485. case paddle.AttributeType.FLOAT:
  486. this._type = 'float32';
  487. this._value = attr.f;
  488. break;
  489. case paddle.AttributeType.FLOATS:
  490. this._type = 'float[]';
  491. this._value = Array.from(attr.floats);
  492. break;
  493. case paddle.AttributeType.INT:
  494. this._type = 'int32';
  495. this._value = attr.i;
  496. break;
  497. case paddle.AttributeType.INTS:
  498. this._type = 'int32[]';
  499. this._value = Array.from(attr.ints);
  500. break;
  501. case paddle.AttributeType.LONG:
  502. this._type = 'int64';
  503. break;
  504. case paddle.AttributeType.LONGS:
  505. this._type = 'int64[]';
  506. break;
  507. default:
  508. break;
  509. }
  510. switch (this._name) {
  511. case 'use_mkldnn':
  512. case 'use_cudnn':
  513. case 'op_callstack':
  514. case 'op_role':
  515. case 'op_role_var':
  516. case 'op_namescope':
  517. case 'is_test':
  518. this._visible = false;
  519. break;
  520. default:
  521. break;
  522. }
  523. if (schema) {
  524. if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  525. const defaultValue = schema.default;
  526. const value = this._value;
  527. if (defaultValue == value) {
  528. this._visible = false;
  529. }
  530. else if (Array.isArray(value) && Array.isArray(defaultValue) && value.length == defaultValue.length) {
  531. if (value.every((item, index) => item == defaultValue[index])) {
  532. this._visible = false;
  533. }
  534. }
  535. }
  536. }
  537. }
  538. get name() {
  539. return this._name;
  540. }
  541. get type() {
  542. return this._type;
  543. }
  544. get value() {
  545. return this._value;
  546. }
  547. get visible() {
  548. return this._visible == false ? false : true;
  549. }
  550. };
  551. paddle.Tensor = class {
  552. constructor(type, data, kind) {
  553. this._type = type;
  554. this._data = data;
  555. this._kind = kind || '';
  556. }
  557. get kind() {
  558. return this._kind;
  559. }
  560. get type() {
  561. return this._type;
  562. }
  563. get state() {
  564. return this._context().state || null;
  565. }
  566. get value() {
  567. const context = this._context();
  568. if (context.state) {
  569. return null;
  570. }
  571. context.limit = Number.MAX_SAFE_INTEGER;
  572. return this._decode(context, 0);
  573. }
  574. toString() {
  575. const context = this._context();
  576. if (context.state) {
  577. return '';
  578. }
  579. context.limit = 10000;
  580. const value = this._decode(context, 0);
  581. return paddle.Tensor._stringify(value, '', ' ');
  582. }
  583. _context() {
  584. const context = {};
  585. context.index = 0;
  586. context.count = 0;
  587. context.state = null;
  588. if (!this._data) {
  589. context.state = 'Tensor data is empty.';
  590. return context;
  591. }
  592. if (!this._type) {
  593. context.state = 'Tensor has no data type.';
  594. return context;
  595. }
  596. context.dataType = this._type.dataType;
  597. context.shape = this._type.shape.dimensions;
  598. context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  599. switch (context.dataType) {
  600. case 'float32':
  601. case 'int32':
  602. case 'int64':
  603. break;
  604. default:
  605. context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
  606. break;
  607. }
  608. return context;
  609. }
  610. _decode(context, dimension) {
  611. const shape = context.shape.length !== 0 ? context.shape : [ 1 ];
  612. const results = [];
  613. const size = shape[dimension];
  614. if (dimension == shape.length - 1) {
  615. for (let i = 0; i < size; i++) {
  616. if (context.count > context.limit) {
  617. results.push('...');
  618. return results;
  619. }
  620. switch (context.dataType) {
  621. case 'float32':
  622. results.push(context.view.getFloat32(context.index, true));
  623. context.index += 4;
  624. context.count++;
  625. break;
  626. case 'int32':
  627. results.push(context.view.getInt32(context.index, true));
  628. context.index += 4;
  629. context.count++;
  630. break;
  631. case 'int64':
  632. results.push(context.view.getInt64(context.index, true));
  633. context.index += 8;
  634. context.count++;
  635. break;
  636. default:
  637. throw new paddle.Error("Unsupported tensor data type '" + context.dataType + "'.");
  638. }
  639. }
  640. }
  641. else {
  642. for (let j = 0; j < size; j++) {
  643. if (context.count > context.limit) {
  644. results.push('...');
  645. return results;
  646. }
  647. results.push(this._decode(context, dimension + 1));
  648. }
  649. }
  650. if (context.shape.length == 0) {
  651. return results[0];
  652. }
  653. return results;
  654. }
  655. static _stringify(value, indentation, indent) {
  656. if (Array.isArray(value)) {
  657. const result = [];
  658. result.push(indentation + '[');
  659. const items = value.map((item) => paddle.Tensor._stringify(item, indentation + indent, indent));
  660. if (items.length > 0) {
  661. result.push(items.join(',\n'));
  662. }
  663. result.push(indentation + ']');
  664. return result.join('\n');
  665. }
  666. if (typeof value == 'string') {
  667. return indentation + value;
  668. }
  669. if (value == Infinity) {
  670. return indentation + 'Infinity';
  671. }
  672. if (value == -Infinity) {
  673. return indentation + '-Infinity';
  674. }
  675. if (isNaN(value)) {
  676. return indentation + 'NaN';
  677. }
  678. return indentation + value.toString();
  679. }
  680. };
  681. paddle.TensorType = class {
  682. constructor(dataType, shape) {
  683. this._dataType = dataType;
  684. this._shape = shape;
  685. }
  686. get dataType() {
  687. return this._dataType;
  688. }
  689. get shape() {
  690. return this._shape;
  691. }
  692. get denotation() {
  693. return this._denotation;
  694. }
  695. toString() {
  696. return this._dataType + this._shape.toString();
  697. }
  698. };
  699. paddle.TensorShape = class {
  700. constructor(dimensions) {
  701. dimensions = dimensions.map((dimension) => Number.isInteger(dimension) ? dimension : dimension.toNumber());
  702. this._dimensions = dimensions.map((dimension) => {
  703. return dimension != -1 ? dimension : '?';
  704. });
  705. }
  706. get dimensions() {
  707. return this._dimensions;
  708. }
  709. toString() {
  710. return (this._dimensions && this._dimensions.length) ? ('[' + this._dimensions.join(',') + ']') : '';
  711. }
  712. };
  713. paddle.Entries = class {
  714. static open(context) {
  715. const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).size > 0);
  716. if (extension) {
  717. const entries = new Map(Array.from(context.entries(extension)).filter((entry) => !entry[0].endsWith('/') && !entry[0].split('/').pop().startsWith('.')).slice());
  718. if (entries.size > 2 && Array.from(entries).every((entry) => entry[0].split('_').length > 0 && entry[1].peek(16).every((value) => value === 0x00))) {
  719. return new paddle.Entries(entries);
  720. }
  721. }
  722. return null;
  723. }
  724. constructor(data) {
  725. this._data = data;
  726. }
  727. get format() {
  728. return 'PaddlePaddle Weights';
  729. }
  730. get weights() {
  731. this._read();
  732. return this._weights;
  733. }
  734. _read() {
  735. if (!this._weights) {
  736. let rootFolder = null;
  737. for (const entry of this._data) {
  738. const name = entry[0];
  739. if (!name.startsWith('.') || name.startsWith('./')) {
  740. const parts = name.split('/');
  741. const folder = ((parts.length > 2 && parts[0] === '.') ? ('./' + parts[1] + '/') : (parts.length > 1 ? parts[0] + '/' : ''));
  742. rootFolder = (rootFolder === null) ? folder : (rootFolder !== '' && folder !== rootFolder) ? '' : folder;
  743. }
  744. }
  745. this._weights = new Map();
  746. for (const entry of this._data) {
  747. if (entry[0].startsWith(rootFolder)) {
  748. const name = entry[0].substring(rootFolder.length);
  749. const stream = entry[1];
  750. const tensor = paddle.Utility.openTensorDesc(stream);
  751. this._weights.set(name, tensor);
  752. }
  753. }
  754. }
  755. }
  756. };
  757. paddle.Pickle = class {
  758. static open(context) {
  759. const obj = context.open('pkl');
  760. const container = new paddle.Pickle(obj);
  761. return container.weights !== null ? container : null;
  762. }
  763. constructor(obj) {
  764. this._weights = null;
  765. if (obj && !Array.isArray(obj) && (obj instanceof Map || Object(obj) === obj)) {
  766. const entries = (obj) => {
  767. return obj instanceof Map ? Array.from(obj) : Object(obj) === obj ? Object.entries(obj) : [];
  768. };
  769. const filter = (obj) => {
  770. const list = [];
  771. if (obj && !Array.isArray(obj)) {
  772. for (const entry of entries(obj)) {
  773. const name = entry[0];
  774. if (name !== 'StructuredToParameterName@@') {
  775. let value = entry[1];
  776. value = value && Array.isArray(value) && value.length === 2 && value[0] === name ? value[1] : value;
  777. if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
  778. list.push([ name, value ]);
  779. }
  780. }
  781. }
  782. }
  783. return list;
  784. };
  785. const weights = filter(obj);
  786. if (weights.length > 0) {
  787. this._weights = weights;
  788. }
  789. else {
  790. const list = entries(obj);
  791. if (list.filter((entry) => entry[0] !== 'StructuredToParameterName@@').length === 1) {
  792. const weights = filter(list[0][1]);
  793. if (weights.length > 0) {
  794. this._weights = weights;
  795. }
  796. }
  797. if (this._weights === null && list.filter((entry) => entry[0] === 'StructuredToParameterName@@').length > 0) {
  798. this._weights = [];
  799. }
  800. }
  801. }
  802. }
  803. get format() {
  804. return 'PaddlePaddle Pickle';
  805. }
  806. get weights() {
  807. if (this._weights && Array.isArray(this._weights)) {
  808. const weights = new Map();
  809. for (const entry of this._weights) {
  810. const name = entry[0];
  811. const value = entry[1];
  812. const type = new paddle.TensorType(value.dtype.__name__, new paddle.TensorShape(value.shape));
  813. const data = value.data;
  814. const tensor = new paddle.Tensor(type, data, 'NumPy Array');
  815. weights.set(name, tensor);
  816. }
  817. this._weights = weights;
  818. }
  819. return this._weights;
  820. }
  821. };
  paddle.NaiveBuffer = class {

    /**
     * Detects a Paddle Lite naive-buffer file from its first four bytes.
     *
     * Byte 0 is read as the meta version and byte 1 must be 0x00. For meta
     * version 2, _read() skips those two bytes and decodes the next 16 bytes as
     * the optimizer version string, which is expected to start with "v2"
     * (0x76 0x32).
     *
     * @param {object} context - Supplies `stream` (with `length` and `peek(n)`)
     *     and `identifier` (the file name).
     * @returns {paddle.NaiveBuffer|null} A reader tagged with meta version -1 when a
     *     `__model__.nb` / `param.nb` file lacks that header, a reader for meta
     *     versions 0..2, or null when the header does not match.
     */
    static open(context) {
      const stream = context.stream;
      if (stream && stream.length > 4) {
        const buffer = stream.peek(4);
        if (context.identifier === '__model__.nb' || context.identifier === 'param.nb') {
          // Expected header: meta version <= 2, 0x00, then 'v' '2'. Bytes 2 and 3
          // must be tested separately; comparing byte 2 against two different
          // values would make the deprecated branch unconditional.
          const hasVersionHeader = buffer[0] <= 2 && buffer[1] === 0x00 && buffer[2] === 0x76 && buffer[3] === 0x32;
          if (!hasVersionHeader) {
            return new paddle.NaiveBuffer(stream, -1);
          }
        }
        if (buffer[1] === 0x00 && buffer[0] <= 2) {
          return new paddle.NaiveBuffer(stream, buffer[0]);
        }
      }
      return null;
    }

    /**
     * @param {*} stream - The file stream; parsed lazily on first property access.
     * @param {number} meta_version - -1 (legacy files) or 0..2, as chosen by open().
     */
    constructor(stream, meta_version) {
      this.stream = stream;
      this.meta_version = meta_version;
    }

    /** @returns {string} 'Paddle Lite' plus the optimizer version, when present. */
    get format() {
      this._read();
      return this._format;
    }

    /** @returns {*} The ProgramDesc decoded from the topology section (meta version 2). */
    get model() {
      this._read();
      return this._model;
    }

    /** @returns {Map<string, paddle.Tensor>} Parameters keyed by ParamDesc name (meta version 2). */
    get weights() {
      this._read();
      return this._weights;
    }

    /**
     * Parses the stream once, filling _format, _model and _weights.
     *
     * @throws {paddle.Error} For meta versions -1, 0 and 1 (deprecated) and for
     *     any other unsupported meta version.
     */
    _read() {
      if (this.stream) {
        const reader = new base.BinaryReader(this.stream);
        if (this.meta_version >= 2) {
          // Skip the two meta-version bytes already inspected by open().
          reader.skip(2);
        }
        // Parse at most once; the stream reference is dropped even if parsing throws.
        delete this.stream;
        const decoder = new TextDecoder();
        const opt_version = reader.read(16);
        // The version string is NUL-terminated unless it fills all 16 bytes.
        const end = opt_version.indexOf(0x00);
        const version = decoder.decode(end === -1 ? opt_version : opt_version.slice(0, end));
        this._format = 'Paddle Lite' + (version ? ' ' + version : '');
        const topo_size = reader.uint64();
        const openProgramDesc = (buffer) => {
          const reader = flatbuffers.BinaryReader.open(buffer);
          return paddle.schema.ProgramDesc.create(reader);
        };
        const openParamDesc = (buffer) => {
          const reader = flatbuffers.BinaryReader.open(buffer);
          return paddle.schema.ParamDesc.create(reader);
        };
        switch (this.meta_version) {
          case -1: {
            throw new paddle.Error('Paddle Lite naive buffer format is deprecated.');
          }
          case 0:
          case 1: {
            throw new paddle.Error("Paddle Lite meta format '" + this.meta_version.toString() + "' is deprecated.");
          }
          case 2: {
            const topo_data = new Uint8Array(topo_size);
            topo_data.set(reader.read(topo_size), 0);
            this._model = openProgramDesc(topo_data);
            reader.uint16(); // version
            reader.uint16(); // meta_size
            const header_size = reader.uint16();
            const params_size = reader.uint16();
            reader.uint32(); // max_tensor_size
            // NOTE(review): skips the rest of the header; the -6 adjustment is kept
            // as-is — confirm against the Paddle Lite writer.
            reader.skip(header_size - 6);
            this._weights = new Map();
            for (let i = 0; i < params_size; i++) {
              const total_size = reader.uint32();
              const offset = reader.uint32();
              const param_bytes = total_size - offset;
              const param_data = reader.read(param_bytes);
              const desc = openParamDesc(param_data);
              const data = desc.variable.data;
              const data_type = desc.variable.data_type;
              const dim = desc.variable.dim;
              const type = paddle.Utility.createTensorType(data_type, dim);
              const tensor = new paddle.Tensor(type, data);
              this._weights.set(desc.name, tensor);
            }
            break;
          }
          default: {
            throw new paddle.Error("Unsupported Paddle Lite naive buffer meta format '" + this.meta_version.toString() + "'.");
          }
        }
      }
    }
  };
  paddle.Utility = class {

    /**
     * Creates a paddle.TensorType from a numeric paddle.DataType id and a shape.
     *
     * On first use an id -> name table is built from paddle.DataType: key names
     * are lower-cased, and bool/bf16/fp16/fp32/fp64 become
     * boolean/bfloat16/float16/float32/float64. The table is cached in
     * paddle.Utility._dataTypes.
     *
     * @param {number} data_type - A paddle.DataType value.
     * @param {*} shape - Dimensions forwarded to paddle.TensorShape.
     * @returns {paddle.TensorType} The type; its data type is '?' for ids with no table entry.
     */
    static createTensorType(data_type, shape) {
      if (!paddle.Utility._dataTypes) {
        const ids = Object.values(paddle.DataType);
        // +1 so the largest id is itself a valid index.
        const names = new Array(Math.max(...ids) + 1);
        const renames = new Map([ [ 'bool', 'boolean' ], [ 'bf16', 'bfloat16' ], [ 'fp16', 'float16' ], [ 'fp32', 'float32' ], [ 'fp64', 'float64' ] ]);
        for (const [name, id] of Object.entries(paddle.DataType)) {
          const key = name.toLowerCase();
          names[id] = renames.has(key) ? renames.get(key) : key;
        }
        // Cache only a fully built table.
        paddle.Utility._dataTypes = names;
      }
      const table = paddle.Utility._dataTypes;
      // Out-of-range ids and gaps in paddle.DataType (e.g. 16) both map to '?'.
      const entry = data_type >= 0 && data_type < table.length ? table[data_type] : undefined;
      const dataType = entry === undefined ? '?' : entry;
      return new paddle.TensorType(dataType, new paddle.TensorShape(shape));
    }

    /**
     * Reads one serialized tensor from `stream`.
     *
     * The layout read is:
     * - a 16-byte prefix that must be all zeros (presumably the version and
     *   LoD fields of a LoD-free tensor — confirm);
     * - a uint32 length;
     * - a VarType.TensorDesc protobuf of that length;
     * - the raw element data.
     *
     * @param {*} stream - Object whose `read(n)` returns the next n bytes.
     * @returns {paddle.Tensor} Tensor holding the raw data bytes.
     * @throws {paddle.Error} On a non-zero prefix or a data type without a fixed item size.
     */
    static openTensorDesc(stream) {
      const signature = stream.read(16);
      if (!signature.every((value) => value === 0x00)) {
        throw new paddle.Error('Invalid paddle.TensorDesc signature.');
      }
      const length = new base.BinaryReader(stream.read(4)).uint32();
      const buffer = stream.read(length);
      const reader = protobuf.BinaryReader.open(buffer);
      const tensorDesc = paddle.proto.VarType.TensorDesc.decode(reader);
      // dims entries expose toNumber() (presumably protobuf int64 values).
      const size = tensorDesc.dims.reduce((a, b) => a * b.toNumber(), 1);
      let itemsize = 0;
      switch (tensorDesc.data_type) {
        case paddle.DataType.BOOL:
        case paddle.DataType.INT8:
        case paddle.DataType.UINT8: itemsize = 1; break;
        case paddle.DataType.INT16:
        case paddle.DataType.FP16:
        case paddle.DataType.BF16: itemsize = 2; break;
        case paddle.DataType.INT32:
        case paddle.DataType.FP32: itemsize = 4; break;
        case paddle.DataType.INT64:
        case paddle.DataType.FP64:
        case paddle.DataType.COMPLEX64: itemsize = 8; break;
        case paddle.DataType.COMPLEX128: itemsize = 16; break;
        default: throw new paddle.Error("Invalid inference params data type '" + tensorDesc.data_type + "'.");
      }
      const type = paddle.Utility.createTensorType(tensorDesc.data_type, tensorDesc.dims);
      const data = stream.read(itemsize * size);
      return new paddle.Tensor(type, data);
    }
  };
  // Numeric data-type ids used by this loader (paddle.Utility builds its
  // id -> name table from these entries). The values presumably mirror
  // PaddlePaddle's framework.proto VarType.Type enum — confirm. Id 16 is not
  // assigned here.
  paddle.DataType = {
    'BOOL': 0, 'INT16': 1, 'INT32': 2, 'INT64': 3,
    'FP16': 4, 'FP32': 5, 'FP64': 6,
    'LOD_TENSOR': 7, 'SELECTED_ROWS': 8, 'FEED_MINIBATCH': 9, 'FETCH_LIST': 10,
    'STEP_SCOPES': 11, 'LOD_RANK_TABLE': 12, 'LOD_TENSOR_ARRAY': 13, 'PLACE_LIST': 14,
    'READER': 15, 'RAW': 17, 'TUPLE': 18, 'SIZE_T': 19,
    'UINT8': 20, 'INT8': 21, 'BF16': 22, 'COMPLEX64': 23, 'COMPLEX128': 24,
  };
  // Numeric attribute-type ids; presumably mirrors PaddlePaddle's framework.proto
  // AttrType enum — confirm.
  paddle.AttributeType = {
    'INT': 0, 'FLOAT': 1, 'STRING': 2,
    'INTS': 3, 'FLOATS': 4, 'STRINGS': 5,
    'BOOLEAN': 6, 'BOOLEANS': 7,
    'BLOCK': 8, 'LONG': 9, 'BLOCKS': 10, 'LONGS': 11,
    'FLOAT64S': 12,
  };
  // Error type thrown by this loader; its `name` carries a format-specific
  // label instead of the generic 'Error'.
  paddle.Error = class extends Error {
    constructor(message) {
      super(message);
      const label = 'Error loading PaddlePaddle model.';
      this.name = label;
    }
  };
  // Export the model factory only when a CommonJS `module.exports` object is present.
  if (typeof module !== 'undefined') {
    if (typeof module.exports === 'object') {
      module.exports.ModelFactory = paddle.ModelFactory;
    }
  }