paddle.js 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897
  1. /* jshint esversion: 6 */
  2. var paddle = paddle || {};
  3. var protobuf = protobuf || require('./protobuf');
  4. paddle.ModelFactory = class {
  5. match(context) {
  6. const identifier = context.identifier;
  7. const extension = identifier.split('.').pop().toLowerCase();
  8. if (identifier === '__model__' || extension === '__model__' || extension === 'paddle' || extension === 'pdmodel') {
  9. const tags = context.tags('pb');
  10. if (tags.get(1) === 2) {
  11. return true;
  12. }
  13. }
  14. if (extension === 'pbtxt' || extension === 'txt') {
  15. const tags = context.tags('pbtxt');
  16. if (tags.has('blocks')) {
  17. return true;
  18. }
  19. }
  20. if (paddle.Container.open(context)) {
  21. return true;
  22. }
  23. const stream = context.stream;
  24. if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
  25. return true;
  26. }
  27. return false;
  28. }
  29. open(context) {
  30. return paddle.Metadata.open(context).then((metadata) => {
  31. return context.require('./paddle-proto').then(() => {
  32. paddle.proto = protobuf.get('paddle').paddle.framework.proto;
  33. const stream = context.stream;
  34. const identifier = context.identifier;
  35. const parts = identifier.split('.');
  36. const extension = parts.pop().toLowerCase();
  37. const base = parts.join('.');
  38. const openProgram = (stream, extension) => {
  39. const program = {};
  40. program.format = 'PaddlePaddle';
  41. switch (extension) {
  42. case 'pbtxt':
  43. case 'txt': {
  44. try {
  45. const reader = protobuf.TextReader.open(stream);
  46. program.desc = paddle.proto.ProgramDesc.decodeText(reader);
  47. }
  48. catch (error) {
  49. const message = error && error.message ? error.message : error.toString();
  50. throw new paddle.Error('File text format is not paddle.ProgramDesc (' + message.replace(/\.$/, '') + ').');
  51. }
  52. break;
  53. }
  54. default: {
  55. try {
  56. const reader = protobuf.BinaryReader.open(stream);
  57. program.desc = paddle.proto.ProgramDesc.decode(reader);
  58. }
  59. catch (error) {
  60. const message = error && error.message ? error.message : error.toString();
  61. throw new paddle.Error('File format is not paddle.ProgramDesc (' + message.replace(/\.$/, '') + ').');
  62. }
  63. break;
  64. }
  65. }
  66. const programDesc = program.desc;
  67. if (programDesc.version && programDesc.version.version && programDesc.version.version.toNumber) {
  68. const version = programDesc.version.version.toNumber();
  69. if (version > 0) {
  70. const list = [ Math.floor(version / 1000000) % 1000, Math.floor(version / 1000) % 1000, version % 1000 ];
  71. if (list.slice(-1).pop() === 0) {
  72. list.pop();
  73. if (list.slice(-1).pop() === 0) {
  74. list.pop();
  75. }
  76. }
  77. program.format += ' v' + list.map((item) => item.toString()).join('.');
  78. }
  79. }
  80. const variables = new Set();
  81. for (const block of programDesc.blocks) {
  82. const blockVars = new Set();
  83. for (const variable of block.vars) {
  84. if (variable.persistable && variable.type &&
  85. variable.type.type != paddle.proto.VarType.Type.FETCH_LIST &&
  86. variable.type.type != paddle.proto.VarType.Type.FEED_MINIBATCH) {
  87. blockVars.add(variable.name);
  88. }
  89. }
  90. for (const op of block.ops) {
  91. for (const input of op.inputs) {
  92. for (const argument of input.arguments) {
  93. if (blockVars.has(argument)) {
  94. variables.add(argument);
  95. }
  96. }
  97. }
  98. }
  99. }
  100. program.vars = Array.from(variables).sort();
  101. return program;
  102. };
  103. const loadParams = (metadata, program, stream) => {
  104. const tensors = new Map();
  105. while (stream.position < stream.length) {
  106. tensors.set(program.vars.shift(), new paddle.Tensor(null, stream));
  107. }
  108. return new paddle.Model(metadata, program.format, program.desc, tensors);
  109. };
  110. const container = paddle.Container.open(context);
  111. if (container) {
  112. return new paddle.Model(metadata, container.format, null, container.weights);
  113. }
  114. else if (stream.length > 16 && stream.peek(16).every((value) => value === 0x00)) {
  115. const file = identifier !== 'params' ? base + '.pdmodel' : 'model';
  116. return context.request(file, null).then((stream) => {
  117. const program = openProgram(stream, '');
  118. return loadParams(metadata, program, context.stream);
  119. });
  120. }
  121. else {
  122. const program = openProgram(context.stream, extension);
  123. const loadEntries = (context, program) => {
  124. const promises = program.vars.map((name) => context.request(name, null));
  125. const tensors = new Map();
  126. return Promise.all(promises).then((streams) => {
  127. for (let i = 0; i < program.vars.length; i++) {
  128. tensors.set(program.vars[i], new paddle.Tensor(null, streams[i]));
  129. }
  130. return new paddle.Model(metadata, program.format, program.desc, tensors);
  131. }).catch((/* err */) => {
  132. return new paddle.Model(metadata, program.format, program.desc, tensors);
  133. });
  134. };
  135. if (extension === 'pdmodel') {
  136. return context.request(base + '.pdiparams', null).then((stream) => {
  137. return loadParams(metadata, program, stream);
  138. }).catch((/* err */) => {
  139. return loadEntries(context, program);
  140. });
  141. }
  142. if (identifier === 'model') {
  143. return context.request('params', null).then((stream) => {
  144. return loadParams(metadata, program, stream);
  145. }).catch((/* err */) => {
  146. return loadEntries(context, program);
  147. });
  148. }
  149. return loadEntries(context, program);
  150. }
  151. });
  152. });
  153. }
  154. };
  155. paddle.Model = class {
  156. constructor(metadata, format, programDesc, tensors) {
  157. this._format = format;
  158. this._graphs = programDesc ?
  159. programDesc.blocks.map((block) => new paddle.Graph(metadata, block, tensors)) :
  160. [ new paddle.Graph(metadata, null, tensors) ];
  161. }
  162. get format() {
  163. return this._format;
  164. }
  165. get graphs() {
  166. return this._graphs;
  167. }
  168. };
  169. paddle.Graph = class {
  170. constructor(metadata, block, tensors) {
  171. this._nodes = [];
  172. this._inputs = [];
  173. this._outputs = [];
  174. if (block) {
  175. this._name = block.idx.toString();
  176. const args = new Map();
  177. for (const variable of block.vars) {
  178. const type = variable.type && variable.type.type && variable.type.lod_tensor && variable.type.lod_tensor.tensor ? paddle.Utility.createTensorType(variable.type.lod_tensor.tensor) : null;
  179. const tensor = variable.persistable && variable.type && variable.type.type != paddle.proto.VarType.Type.FETCH_LIST && variable.type.type != paddle.proto.VarType.Type.FEED_MINIBATCH ? (tensors.get(variable.name) || new paddle.Tensor(type)) : null;
  180. args.set(variable.name, new paddle.Argument(variable.name, type, tensor));
  181. }
  182. const scope = {};
  183. for (let i = 0; i < block.ops.length; i++) {
  184. for (const input of block.ops[i].inputs) {
  185. input.arguments = input.arguments.map((argument) => scope[argument] ? scope[argument] : argument);
  186. }
  187. for (const output of block.ops[i].outputs) {
  188. output.arguments = output.arguments.map((argument) => {
  189. if (scope[argument]) {
  190. const next = argument + '\n' + i.toString(); // custom argument id
  191. scope[argument] = next;
  192. return next;
  193. }
  194. scope[argument] = argument;
  195. return argument;
  196. });
  197. }
  198. }
  199. for (const op of block.ops) {
  200. for (const input of op.inputs) {
  201. for (const argument of input.arguments) {
  202. const name = argument;
  203. if (!args.has(name)) {
  204. args.set(name, new paddle.Argument(name, null, null));
  205. }
  206. }
  207. }
  208. for (const output of op.outputs) {
  209. for (const argument of output.arguments) {
  210. const name = argument;
  211. if (!args.has(name)) {
  212. args.set(name, new paddle.Argument(name, null, null));
  213. }
  214. }
  215. }
  216. }
  217. let lastNode = null;
  218. let lastOutput = null;
  219. for (const op of block.ops) {
  220. if (op.type == 'feed') {
  221. const inputName = op.attrs.filter((attr) => attr.name == 'col')[0].i.toString();
  222. this._inputs.push(new paddle.Parameter(inputName, op.outputs[0].arguments.map((id) => args.get(id))));
  223. }
  224. else if (op.type == 'fetch') {
  225. const outputName = op.attrs.filter((attr) => attr.name == 'col')[0].i.toString();
  226. this._outputs.push(new paddle.Parameter(outputName, op.inputs[0].arguments.map((id) => args.get(id))));
  227. }
  228. else {
  229. const node = new paddle.Node(metadata, op, args);
  230. if (op.inputs.length == 1 && op.inputs[0].arguments.length == 1 &&
  231. op.outputs.length >= 1 && op.outputs[0].arguments.length == 1 &&
  232. op.inputs[0].arguments[0].split('\n').shift() == op.outputs[0].arguments[0].split('\n').shift() &&
  233. lastNode &&
  234. lastOutput == op.inputs[0].arguments[0].split('\n').shift()) {
  235. lastNode.chain.push(node);
  236. }
  237. else {
  238. this._nodes.push(node);
  239. lastNode = null;
  240. lastOutput = null;
  241. if (op.outputs.length == 1 && op.outputs[0].arguments.length == 1) {
  242. lastNode = node;
  243. lastOutput = op.outputs[0].arguments[0].split('\n').shift();
  244. }
  245. }
  246. }
  247. }
  248. }
  249. else {
  250. const args = new Map();
  251. const ops = new Map();
  252. for (const pair of tensors) {
  253. const name = pair[0];
  254. const tensor = pair[1];
  255. args.set(name, new paddle.Argument(name, tensor.type, tensor));
  256. const separator = [ '.', '_' ].find((separator) => name.split(separator).length > 1);
  257. const parts = name.split(separator);
  258. const parameter_name = parts.pop();
  259. const op_name = parts.join(separator);
  260. if (!ops.has(op_name)) {
  261. ops.set(op_name, { name: op_name, type: 'Weights', inputs: [] });
  262. }
  263. const op = ops.get(op_name);
  264. op.inputs.push({ parameter: parameter_name, arguments: [ name ] });
  265. }
  266. for (const pair of ops) {
  267. const op = pair[1];
  268. this._nodes.push(new paddle.Node(metadata, op, args));
  269. }
  270. }
  271. }
  272. get name() {
  273. return this._name;
  274. }
  275. get inputs() {
  276. return this._inputs;
  277. }
  278. get outputs() {
  279. return this._outputs;
  280. }
  281. get nodes() {
  282. return this._nodes;
  283. }
  284. };
  285. paddle.Parameter = class {
  286. constructor(name, args) {
  287. this._name = name;
  288. this._arguments = args;
  289. }
  290. get name() {
  291. return this._name;
  292. }
  293. get visible() {
  294. return true;
  295. }
  296. get arguments() {
  297. return this._arguments;
  298. }
  299. };
  300. paddle.Argument = class {
  301. constructor(name, type, initializer) {
  302. if (typeof name !== 'string') {
  303. throw new paddle.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  304. }
  305. this._name = name;
  306. this._type = type || null;
  307. this._initializer = initializer || null;
  308. }
  309. get name() {
  310. return this._name;
  311. }
  312. get type() {
  313. if (this._type) {
  314. return this._type;
  315. }
  316. if (this._initializer) {
  317. return this._initializer.type;
  318. }
  319. return null;
  320. }
  321. get initializer() {
  322. return this._initializer;
  323. }
  324. };
  325. paddle.Node = class {
  326. constructor(metadata, op, args) {
  327. const type = op.type;
  328. this._type = metadata.type(type) || { name: type };
  329. this._name = op.name || '';
  330. this._attributes = [];
  331. this._inputs = [];
  332. this._outputs = [];
  333. this._chain = [];
  334. if (op.attrs) {
  335. for (const attr of op.attrs) {
  336. const schema = metadata.attribute(type, this._name);
  337. this._attributes.push(new paddle.Attribute(schema, attr));
  338. }
  339. }
  340. if (op.inputs) {
  341. for (const input of op.inputs) {
  342. if (input.arguments.length > 0) {
  343. this._inputs.push(new paddle.Parameter(input.parameter, input.arguments.map((name) => args.get(name))));
  344. }
  345. }
  346. }
  347. if (op.outputs) {
  348. for (const output of op.outputs) {
  349. if (output.arguments.length > 0) {
  350. this._outputs.push(new paddle.Parameter(output.parameter, output.arguments.map((name) => args.get(name))));
  351. }
  352. }
  353. }
  354. this._update(this._inputs, 'X');
  355. this._update(this._inputs, 'Input');
  356. this._update(this._outputs, 'Y');
  357. this._update(this._outputs, 'Out');
  358. }
  359. get type() {
  360. return this._type;
  361. }
  362. get name() {
  363. return this._name;
  364. }
  365. get attributes() {
  366. return this._attributes;
  367. }
  368. get inputs() {
  369. return this._inputs;
  370. }
  371. get outputs() {
  372. return this._outputs;
  373. }
  374. get chain() {
  375. return this._chain;
  376. }
  377. _update(list, name) {
  378. let item = null;
  379. for (let i = 0; i < list.length; i++) {
  380. if (list[i].name == name) {
  381. item = list[i];
  382. list.splice(i, 1);
  383. break;
  384. }
  385. }
  386. if (item) {
  387. list.splice(0, 0, item);
  388. }
  389. }
  390. };
  391. paddle.Attribute = class {
  392. constructor(schema, attr) {
  393. this._name = attr.name;
  394. this._value = '?';
  395. switch (attr.type) {
  396. case paddle.proto.AttrType.STRING:
  397. this._type = 'string';
  398. this._value = attr.s;
  399. break;
  400. case paddle.proto.AttrType.STRINGS:
  401. this._type = 'string[]';
  402. this._value = attr.strings;
  403. break;
  404. case paddle.proto.AttrType.BOOLEAN:
  405. this._type = 'boolean';
  406. this._value = attr.b;
  407. break;
  408. case paddle.proto.AttrType.BOOLEANS:
  409. this._type = 'boolean[]';
  410. this._value = attr.bools;
  411. break;
  412. case paddle.proto.AttrType.FLOAT:
  413. this._type = 'float32';
  414. this._value = attr.f;
  415. break;
  416. case paddle.proto.AttrType.FLOATS:
  417. this._type = 'float[]';
  418. this._value = attr.floats;
  419. break;
  420. case paddle.proto.AttrType.INT:
  421. this._type = 'int32';
  422. this._value = attr.i;
  423. break;
  424. case paddle.proto.AttrType.INTS:
  425. this._type = 'int32[]';
  426. this._value = attr.ints;
  427. break;
  428. case paddle.proto.AttrType.LONG:
  429. this._type = 'int64';
  430. break;
  431. case paddle.proto.AttrType.LONGS:
  432. this._type = 'int64[]';
  433. break;
  434. default:
  435. break;
  436. }
  437. switch (this._name) {
  438. case 'use_mkldnn':
  439. case 'use_cudnn':
  440. case 'op_callstack':
  441. case 'op_role':
  442. case 'op_role_var':
  443. case 'op_namescope':
  444. case 'is_test':
  445. this._visible = false;
  446. break;
  447. }
  448. if (schema) {
  449. if (Object.prototype.hasOwnProperty.call(schema, 'default')) {
  450. const defaultValue = schema.default;
  451. const value = this._value;
  452. if (defaultValue == value) {
  453. this._visible = false;
  454. }
  455. else if (Array.isArray(value) && Array.isArray(defaultValue) && value.length == defaultValue.length) {
  456. if (value.every((item, index) => { return item == defaultValue[index]; })) {
  457. this._visible = false;
  458. }
  459. }
  460. }
  461. }
  462. }
  463. get name() {
  464. return this._name;
  465. }
  466. get type() {
  467. return this._type;
  468. }
  469. get value() {
  470. return this._value;
  471. }
  472. get visible() {
  473. return this._visible == false ? false : true;
  474. }
  475. };
  476. paddle.Tensor = class {
  477. constructor(type, data) {
  478. this._type = type;
  479. if (data && !Array.isArray(data)) {
  480. if (data.__class__ && data.__class__.__module__ === 'numpy' && data.__class__.__name__ === 'ndarray') {
  481. this._type = new paddle.TensorType(data.dtype.name, new paddle.TensorShape(data.shape));
  482. this._data = data.data;
  483. this._kind = 'NumPy Array';
  484. }
  485. else {
  486. const uint32 = (stream) => {
  487. const buffer = stream.read(4);
  488. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  489. return view.getUint32(0, true);
  490. };
  491. const stream = data;
  492. const signature = stream.read(16);
  493. if (!signature.every((value) => value === 0x00)) {
  494. throw new paddle.Error('Invalid paddle.TensorDesc signature.');
  495. }
  496. const length = uint32(stream);
  497. const buffer = stream.read(length);
  498. const reader = protobuf.BinaryReader.open(buffer);
  499. const tensorDesc = paddle.proto.VarType.TensorDesc.decode(reader);
  500. const size = tensorDesc.dims.reduce((a, b) => a * b.toNumber(), 1);
  501. let itemsize = 0;
  502. switch (tensorDesc.data_type) {
  503. case paddle.proto.VarType.Type.FP32: itemsize = 4; break;
  504. default: throw new paddle.Error("Invalid inference params data type '" + tensorDesc.data_type + "'.");
  505. }
  506. this._type = paddle.Utility.createTensorType(tensorDesc);
  507. this._data = stream.read(itemsize * size);
  508. }
  509. }
  510. }
  511. get kind() {
  512. return this._kind;
  513. }
  514. get type() {
  515. return this._type;
  516. }
  517. get state() {
  518. return this._context().state || null;
  519. }
  520. get value() {
  521. const context = this._context();
  522. if (context.state) {
  523. return null;
  524. }
  525. context.limit = Number.MAX_SAFE_INTEGER;
  526. return this._decode(context, 0);
  527. }
  528. toString() {
  529. const context = this._context();
  530. if (context.state) {
  531. return '';
  532. }
  533. context.limit = 10000;
  534. const value = this._decode(context, 0);
  535. return paddle.Tensor._stringify(value, '', ' ');
  536. }
  537. _context() {
  538. const context = {};
  539. context.index = 0;
  540. context.count = 0;
  541. context.state = null;
  542. if (!this._data) {
  543. context.state = 'Tensor data is empty.';
  544. return context;
  545. }
  546. if (!this._type) {
  547. context.state = 'Tensor has no data type.';
  548. return context;
  549. }
  550. context.dataType = this._type.dataType;
  551. context.shape = this._type.shape.dimensions;
  552. context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
  553. switch (context.dataType) {
  554. case 'float32':
  555. case 'int32':
  556. case 'int64':
  557. break;
  558. default:
  559. context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
  560. break;
  561. }
  562. return context;
  563. }
  564. _decode(context, dimension) {
  565. const shape = context.shape.length !== 0 ? context.shape : [ 1 ];
  566. const results = [];
  567. const size = shape[dimension];
  568. if (dimension == shape.length - 1) {
  569. for (let i = 0; i < size; i++) {
  570. if (context.count > context.limit) {
  571. results.push('...');
  572. return results;
  573. }
  574. switch (context.dataType) {
  575. case 'float32':
  576. results.push(context.view.getFloat32(context.index, true));
  577. context.index += 4;
  578. context.count++;
  579. break;
  580. case 'int32':
  581. results.push(context.view.getInt32(context.index, true));
  582. context.index += 4;
  583. context.count++;
  584. break;
  585. case 'int64':
  586. results.push(context.view.getInt64(context.index, true));
  587. context.index += 8;
  588. context.count++;
  589. break;
  590. }
  591. }
  592. }
  593. else {
  594. for (let j = 0; j < size; j++) {
  595. if (context.count > context.limit) {
  596. results.push('...');
  597. return results;
  598. }
  599. results.push(this._decode(context, dimension + 1));
  600. }
  601. }
  602. if (context.shape.length == 0) {
  603. return results[0];
  604. }
  605. return results;
  606. }
  607. static _stringify(value, indentation, indent) {
  608. if (Array.isArray(value)) {
  609. const result = [];
  610. result.push(indentation + '[');
  611. const items = value.map((item) => paddle.Tensor._stringify(item, indentation + indent, indent));
  612. if (items.length > 0) {
  613. result.push(items.join(',\n'));
  614. }
  615. result.push(indentation + ']');
  616. return result.join('\n');
  617. }
  618. if (typeof value == 'string') {
  619. return indentation + value;
  620. }
  621. if (value == Infinity) {
  622. return indentation + 'Infinity';
  623. }
  624. if (value == -Infinity) {
  625. return indentation + '-Infinity';
  626. }
  627. if (isNaN(value)) {
  628. return indentation + 'NaN';
  629. }
  630. return indentation + value.toString();
  631. }
  632. };
  633. paddle.TensorType = class {
  634. constructor(dataType, shape) {
  635. this._dataType = dataType;
  636. this._shape = shape;
  637. }
  638. get dataType() {
  639. return this._dataType;
  640. }
  641. get shape() {
  642. return this._shape;
  643. }
  644. get denotation() {
  645. return this._denotation;
  646. }
  647. toString() {
  648. return this._dataType + this._shape.toString();
  649. }
  650. };
  651. paddle.TensorShape = class {
  652. constructor(dimensions) {
  653. dimensions = dimensions.map((dimension) => Number.isInteger(dimension) ? dimension : dimension.toNumber());
  654. this._dimensions = dimensions.map((dimension) => {
  655. return dimension != -1 ? dimension : '?';
  656. });
  657. }
  658. get dimensions() {
  659. return this._dimensions;
  660. }
  661. toString() {
  662. return (this._dimensions && this._dimensions.length) ? ('[' + this._dimensions.join(',') + ']') : '';
  663. }
  664. };
  665. paddle.Utility = class {
  666. static createTensorType(desc) {
  667. if (!paddle.Utility._dataTypes) {
  668. const length = Math.max.apply(null, Object.values(paddle.proto.VarType.Type));
  669. paddle.Utility._dataTypes = new Array(length);
  670. for (const key of Object.keys(paddle.proto.VarType.Type)) {
  671. const index = paddle.proto.VarType.Type[key];
  672. let name = key.toLowerCase();
  673. switch (name) {
  674. case 'bool': name = 'boolean'; break;
  675. case 'bf16': name = 'bfloat16'; break;
  676. case 'fp16': name = 'float16'; break;
  677. case 'fp32': name = 'float32'; break;
  678. case 'fp64': name = 'float64'; break;
  679. }
  680. paddle.Utility._dataTypes[index] = name;
  681. }
  682. }
  683. const dataType = desc.data_type < paddle.Utility._dataTypes.length ? paddle.Utility._dataTypes[desc.data_type] : '?';
  684. return new paddle.TensorType(dataType, new paddle.TensorShape(desc.dims));
  685. }
  686. };
  687. paddle.Container = class {
  688. static open(context) {
  689. const extension = [ 'zip', 'tar' ].find((extension) => context.entries(extension).size > 0);
  690. if (extension) {
  691. const entries = new Map(Array.from(context.entries(extension)).filter((entry) => !entry[0].endsWith('/') && !entry[0].split('/').pop().startsWith('.')).slice());
  692. if (entries.size > 2 && Array.from(entries).every((entry) => entry[0].split('_').length > 0 && entry[1].peek(16).every((value) => value === 0x00))) {
  693. return new paddle.Container('entries', entries);
  694. }
  695. }
  696. const obj = context.open('pkl');
  697. if (obj && !Array.isArray(obj) && Object(obj) === obj) {
  698. return new paddle.Container('pdparams', obj);
  699. }
  700. return null;
  701. }
  702. constructor(format, data) {
  703. this._format = format;
  704. this._data = data;
  705. }
  706. get format() {
  707. switch (this._format) {
  708. case 'entries':
  709. return 'PaddlePaddle Weights';
  710. case 'pdparams':
  711. return 'PaddlePaddle Pickle';
  712. }
  713. return null;
  714. }
  715. get model() {
  716. this._initialize();
  717. return this._model;
  718. }
  719. get weights() {
  720. this._initialize();
  721. return this._weights;
  722. }
  723. _initialize() {
  724. if (!this._weights) {
  725. switch (this._format) {
  726. case 'entries': {
  727. let rootFolder = null;
  728. for (const entry of this._data) {
  729. const name = entry[0];
  730. if (name.startsWith('.') && !name.startsWith('./')) {
  731. continue;
  732. }
  733. const parts = name.split('/');
  734. const folder = ((parts.length > 2 && parts[0] === '.') ? ('./' + parts[1] + '/') : (parts.length > 1 ? parts[0] + '/' : ''));
  735. rootFolder = (rootFolder === null) ? folder : (rootFolder !== '' && folder !== rootFolder) ? '' : folder;
  736. }
  737. this._weights = new Map();
  738. for (const entry of this._data) {
  739. if (entry[0].startsWith(rootFolder)) {
  740. const name = entry[0].substring(rootFolder.length);
  741. const stream = entry[1];
  742. const tensor = new paddle.Tensor(null, stream);
  743. this._weights.set(name, tensor);
  744. }
  745. }
  746. break;
  747. }
  748. case 'pdparams': {
  749. const map = null; // this._data['StructuredToParameterName@@'];
  750. this._weights = new Map();
  751. for (const key of Object.keys(this._data)) {
  752. const value = this._data[key];
  753. if (value && !Array.isArray(value) && value.__class__ && value.__class__.__module__ === 'numpy' && value.__class__.__name__ === 'ndarray') {
  754. const name = map ? map[key] : key;
  755. this._weights.set(name, new paddle.Tensor(null, value));
  756. }
  757. }
  758. break;
  759. }
  760. }
  761. delete this._format;
  762. }
  763. }
  764. };
  765. paddle.Metadata = class {
  766. static open(context) {
  767. if (paddle.Metadata._metadata) {
  768. return Promise.resolve(paddle.Metadata._metadata);
  769. }
  770. return context.request('paddle-metadata.json', 'utf-8', null).then((data) => {
  771. paddle.Metadata._metadata = new paddle.Metadata(data);
  772. return paddle.Metadata._metadata;
  773. }).catch(() => {
  774. paddle.Metadata._metadata = new paddle.Metadata(null);
  775. return paddle.Metadata._metadata;
  776. });
  777. }
  778. constructor(data) {
  779. this._map = new Map();
  780. this._attributeCache = new Map();
  781. if (data) {
  782. const metadata = JSON.parse(data);
  783. this._map = new Map(metadata.map((item) => [ item.name, item ]));
  784. }
  785. }
  786. type(name) {
  787. return this._map.get(name) || null;
  788. }
  789. attribute(type, name) {
  790. let map = this._attributeCache.get(type);
  791. if (!map) {
  792. map = new Map();
  793. const metadata = this.type(type);
  794. if (metadata && metadata.attributes && metadata.attributes.length > 0) {
  795. for (const attribute of metadata.attributes) {
  796. map.set(attribute.name, attribute);
  797. }
  798. }
  799. this._attributeCache.set(type, map);
  800. }
  801. return map.get(name) || null;
  802. }
  803. };
  804. paddle.Error = class extends Error {
  805. constructor(message) {
  806. super(message);
  807. this.name = 'Error loading PaddlePaddle model.';
  808. }
  809. };
  810. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  811. module.exports.ModelFactory = paddle.ModelFactory;
  812. }