caffe2.js 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  // Namespace object that collects every class of this Caffe2 model reader.
  var caffe2 = {};
  // Protocol buffer runtime: provides TextReader / BinaryReader and `get()`, which (presumably)
  // returns the generated message classes loaded via context.require('./caffe2-proto').
  var protobuf = require('./protobuf');
  caffe2.ModelFactory = class {
    /**
     * Decides whether the opened file looks like a Caffe2 NetDef.
     * @param {object} context - host file context exposing `identifier`, `tags()` and `stream`.
     * @returns {string|undefined} 'caffe2.pb' or 'caffe2.pbtxt' when recognized, otherwise undefined.
     */
    match(context) {
      const identifier = context.identifier.toLowerCase();
      // `identifier` is already lower-cased, so the extension is too.
      const extension = identifier.split('.').pop();
      if (extension === 'pb') {
        const tags = context.tags('pb');
        if (tags.size > 0 &&
          Array.from(tags.keys()).every((tag) => tag <= 9) &&
          Array.from(tags.values()).every((type) => type <= 4)) {
          // A lone length-delimited field 2 in a file named saved_model.pb is left to another reader
          // (presumably the TensorFlow SavedModel one).
          if (tags.size === 1 && tags.get(2) === 2 && identifier.endsWith('saved_model.pb')) {
            return undefined;
          }
          // Allowed wire type per top-level field number; absent fields are accepted.
          const schema = [[1, 2], [2, 2], [3, 2], [4, 0], [5, 2], [6, 2], [7, 2], [8, 2], [9, 2]];
          if (schema.every(([field, wireType]) => !tags.has(field) || tags.get(field) === wireType)) {
            const stream = context.stream;
            if (stream.length > 3) {
              const buffer = stream.peek(Math.min(stream.length, 67));
              // 0x0A is the key of field 1, length-delimited: accept it when it holds a short
              // printable string that is immediately followed by a field 2 key (0x12).
              if (buffer[0] === 0x0A) {
                const size = buffer[1];
                if (size < 64 &&
                  buffer.length > 2 + size + 1 &&
                  buffer.slice(2, 2 + size).every((c) => c >= 32 && c <= 127) &&
                  buffer[2 + size] === 0x12) {
                  return 'caffe2.pb';
                }
              }
              // 0x12: the stream starts directly with field 2, length-delimited.
              if (buffer[0] === 0x12) {
                return 'caffe2.pb';
              }
            }
          }
        }
      }
      if (extension === 'pbtxt' || extension === 'prototxt') {
        const tags = context.tags('pbtxt');
        // Text NetDefs have top-level `op` entries; the excluded sub-keys presumably belong to
        // other text formats that also use `op` (TODO confirm which).
        if (tags.has('op') && !tags.has('op.attr') && !tags.has('op.graph_op_name') && !tags.has('op.endpoint')) {
          return 'caffe2.pbtxt';
        }
      }
      return undefined;
    }

    /**
     * Decodes the model, pairing it with a sibling init/predict net file when one can be found.
     * Init nets are optional: when the sibling is missing or cannot be decoded the model is
     * built without it.
     * @param {object} context - host file context (`require`, `metadata`, `request`, `stream`, `identifier`).
     * @param {string} target - value previously returned by `match`.
     * @returns {Promise<caffe2.Model>}
     * @throws {caffe2.Error} when the predict net cannot be decoded or `target` is unknown.
     */
    async open(context, target) {
      await context.require('./caffe2-proto');
      const metadata = await context.metadata('caffe2-metadata.json');
      const identifier = context.identifier;
      const parts = identifier.split('.');
      const extension = parts.pop().toLowerCase();
      const base = parts.join('.');
      const lowerBase = base.toLowerCase();
      // Reads the whole content of a sibling file; rejects when the host cannot provide it.
      const readSibling = async (name) => {
        const stream = await context.request(name, null);
        return stream.read();
      };
      // Message text for a decode failure; tolerates throwables without a `message` (even null).
      const describe = (error) => {
        const message = error && error.message ? error.message : String(error);
        return message.replace(/\.$/, '');
      };
      switch (target) {
        case 'caffe2.pbtxt': {
          const openText = (predictBuffer, initBuffer, initTextFormat) => {
            let predict_net = null;
            let init_net = null;
            try {
              caffe2.proto = protobuf.get('caffe2').caffe2;
              const reader = protobuf.TextReader.open(predictBuffer);
              // Presumably invoked by the text reader for fields missing from the schema: keep raw
              // values inside DeviceOption, reject them everywhere else.
              reader.field = function(tag, message) {
                if (message instanceof caffe2.proto.DeviceOption) {
                  message[tag] = this.read();
                  return;
                }
                throw new Error("Unknown field '" + tag + "'" + this.location());
              };
              predict_net = caffe2.proto.NetDef.decodeText(reader);
            } catch (error) {
              throw new caffe2.Error('File text format is not caffe2.NetDef (' + describe(error) + ').');
            }
            try {
              caffe2.proto = protobuf.get('caffe2').caffe2;
              if (initBuffer) {
                if (initTextFormat) {
                  const reader = protobuf.TextReader.open(initBuffer);
                  init_net = caffe2.proto.NetDef.decodeText(reader);
                } else {
                  const reader = protobuf.BinaryReader.open(initBuffer);
                  init_net = caffe2.proto.NetDef.decode(reader);
                }
              }
            } catch (error) {
              // Best effort: an unreadable init net only means the graph is shown without its data.
              init_net = null;
            }
            return new caffe2.Model(metadata, predict_net, init_net);
          };
          if (lowerBase.endsWith('init_net') || lowerBase.startsWith('init_net')) {
            // The opened file is the (text) init net; try to pair it with its predict net.
            try {
              const buffer = await readSibling(identifier.replace('init_net', 'predict_net'));
              return openText(buffer, context.stream.peek(), true);
            } catch (pairError) {
              return openText(context.stream.peek(), null, true);
            }
          }
          if (lowerBase.endsWith('predict_net') || lowerBase.startsWith('predict_net')) {
            const initName = identifier.replace('predict_net', 'init_net');
            // Prefer a binary init net; the extension may be either .pbtxt or .prototxt.
            try {
              const buffer = await readSibling(initName.replace(/\.(pbtxt|prototxt)$/i, '.pb'));
              return openText(context.stream.peek(), buffer, false);
            } catch (binaryError) {
              // Fall back to a text init net sharing the predict net's extension.
              try {
                const buffer = await readSibling(initName);
                return openText(context.stream.peek(), buffer, true);
              } catch (textError) {
                return openText(context.stream.peek(), null, true);
              }
            }
          }
          try {
            const buffer = await readSibling(base + '_init.pb');
            return openText(context.stream.peek(), buffer, false);
          } catch (error) {
            return openText(context.stream.peek(), null, false);
          }
        }
        case 'caffe2.pb': {
          const openBinary = (predictBuffer, initBuffer) => {
            let predict_net = null;
            let init_net = null;
            try {
              caffe2.proto = protobuf.get('caffe2').caffe2;
              const reader = protobuf.BinaryReader.open(predictBuffer);
              predict_net = caffe2.proto.NetDef.decode(reader);
            } catch (error) {
              throw new caffe2.Error('File format is not caffe2.NetDef (' + describe(error) + ').');
            }
            try {
              if (initBuffer) {
                caffe2.proto = protobuf.get('caffe2').caffe2;
                const reader = protobuf.BinaryReader.open(initBuffer);
                init_net = caffe2.proto.NetDef.decode(reader);
              }
            } catch (error) {
              // Best effort: an unreadable init net only means the graph is shown without its data.
              init_net = null;
            }
            return new caffe2.Model(metadata, predict_net, init_net);
          };
          if (lowerBase.endsWith('init_net')) {
            // "<x>init_net.pb" pairs with "<x>predict_net.pb".
            try {
              const buffer = await readSibling(base.replace(/init_net$/, '') + 'predict_net.' + extension);
              return openBinary(buffer, context.stream.peek());
            } catch (error) {
              return openBinary(context.stream.peek(), null);
            }
          }
          if (lowerBase.endsWith('_init')) {
            // "<x>_init.pb" pairs with "<x>.pb".
            try {
              const buffer = await readSibling(base.replace(/_init$/, '') + '.' + extension);
              return openBinary(buffer, context.stream.peek());
            } catch (error) {
              return openBinary(context.stream.peek(), null);
            }
          }
          if (lowerBase.endsWith('predict_net') || lowerBase.startsWith('predict_net')) {
            try {
              const buffer = await readSibling(identifier.replace('predict_net', 'init_net'));
              return openBinary(context.stream.peek(), buffer);
            } catch (error) {
              return openBinary(context.stream.peek(), null);
            }
          }
          try {
            const buffer = await readSibling(base + '_init.' + extension);
            return openBinary(context.stream.peek(), buffer);
          } catch (error) {
            return openBinary(context.stream.peek(), null);
          }
        }
        default: {
          throw new caffe2.Error("Unsupported Caffe2 format '" + target + "'.");
        }
      }
    }
  };
  caffe2.Model = class {
    /**
     * Top-level model wrapper holding a single graph built from the predict net.
     * @param {object} metadata - operator metadata provider.
     * @param {object} predict_net - decoded predict NetDef (its `domain` is exposed, or null).
     * @param {object|null} init_net - optional decoded init NetDef with tensor data.
     */
    constructor(metadata, predict_net, init_net) {
      this._domain = predict_net.domain || null;
      this._graphs = [ new caffe2.Graph(metadata, predict_net, init_net) ];
    }

    /** Fixed format label. */
    get format() { return 'Caffe2'; }

    /** NetDef domain, or null when not set. */
    get domain() { return this._domain; }

    /** Array containing the single graph. */
    get graphs() { return this._graphs; }
  };
  caffe2.Graph = class {
    /**
     * Builds a displayable graph from a Caffe2 NetDef.
     * @param {object} metadata - operator metadata provider, passed on to nodes and attributes.
     * @param {object} netDef - decoded NetDef. NOTE: every op's `input`/`output` arrays are
     *   replaced in place with renamed value ids.
     * @param {object|null} init - optional decoded init NetDef whose single-output ops describe tensors.
     * @throws {caffe2.Error} for an init op type not listed in `dataTypes`, or a conflicting duplicate value.
     */
    constructor(metadata, netDef, init) {
      this._name = netDef.name || '';
      this._type = netDef.type || '';
      this._nodes = [];
      const initializers = new Set();
      const tensors = new Map();
      // Every external input starts as a tensor without data; init ops below may replace it.
      for (const name of netDef.external_input) {
        tensors.set(name, new caffe2.Tensor(name, {}));
      }
      if (init) {
        // Element type produced by each supported init op; null when the op type does not imply one.
        const dataTypes = new Map([
          [ 'GivenTensorFill', 'float32' ],
          [ 'GivenTensorDoubleFill', 'float64' ],
          [ 'GivenTensorBoolFill', 'boolean' ],
          [ 'GivenTensorByteStringToUInt8Fill', 'uint8' ],
          [ 'GivenTensorInt16Fill', 'int16' ],
          [ 'GivenTensorSInt16Fill', 'int16' ],
          [ 'GivenTensorIntFill', 'int32' ],
          [ 'GivenTensorInt64Fill', 'int64' ],
          [ 'GivenTensorStringFill', 'string' ],
          [ 'Int8GivenIntTensorFill', 'int32' ],
          [ 'Int8GivenTensorFill', 'int8' ],
          [ 'XavierFill', null ],
          [ 'ConstantFill', null ]
        ]);
        for (const op of init.op) {
          if (op.output && op.output.length === 1) {
            const name = op.output[0];
            // Op arguments keyed by name; no prototype, so an argument called e.g. "__proto__"
            // is stored as an ordinary key.
            const tensor = Object.create(null);
            for (const arg of op.arg) {
              tensor[arg.name] = arg;
            }
            if (!dataTypes.has(op.type)) {
              throw new caffe2.Error("Unsupported init op '" + op.type + "'.");
            }
            tensor.dataType = dataTypes.get(op.type);
            // Only a `values` float payload other than a single 0 marks the blob as an initializer here.
            const floats = tensor.values && tensor.values.floats;
            if (floats && (floats.length !== 1 || floats[0] !== 0)) {
              initializers.add(name);
            }
            tensors.set(name, new caffe2.Tensor(name, tensor));
          }
        }
      }
      // Ops may write a blob name that already exists. Give each redefinition a unique id
      // ("<name>\n<op index>") and point later readers at the newest id. A Map is used so blob
      // names such as "constructor" do not collide with Object.prototype members.
      const scope = new Map();
      let index = 0;
      for (const op of netDef.op) {
        op.input = op.input.map((input) => scope.get(input) || input);
        op.output = op.output.map((output) => {
          if (scope.get(output)) {
            const next = output + '\n' + index.toString(); // custom argument id
            scope.set(output, next);
            return next;
          }
          scope.set(output, output);
          return output;
        });
        index++;
      }
      const args = new Map();
      // Returns the shared Value for `name`, creating it on first use.
      const arg = (name, type, tensor) => {
        if (!args.has(name)) {
          args.set(name, new caffe2.Value(name, type || null, tensor || null));
        } else if (type || tensor) {
          throw new caffe2.Error("Duplicate value '" + name + "'.");
        }
        return args.get(name);
      };
      // Non-first op inputs that have a known tensor are bound to it and treated as initializers.
      for (const op of netDef.op) {
        let inputIndex = 0;
        for (const name of op.input) {
          if (inputIndex > 0 && tensors.has(name)) {
            if (!args.has(name)) {
              args.set(name, new caffe2.Value(name, null, tensors.get(name)));
            }
            initializers.add(name);
          }
          inputIndex++;
        }
      }
      for (const op of netDef.op) {
        for (const name of op.output) {
          if (tensors.has(name)) {
            initializers.add(name);
          }
        }
      }
      // Strips the "\n<index>" suffix added above to recover the original blob name.
      const blobName = (id) => id.split('\n').shift();
      let lastNode = null;
      let lastOutput = null;
      for (const op of netDef.op) {
        const node = new caffe2.Node(metadata, op, arg);
        // A single-input op that rewrites its input blob, directly following the node (or chain)
        // that produced that blob, is appended to that node's chain instead of becoming a node.
        if (op.input.length === 1 &&
          op.output.length >= 1 &&
          blobName(op.input[0]) === blobName(op.output[0]) &&
          lastNode &&
          lastOutput === blobName(op.input[0])) {
          lastNode.chain.push(node);
        } else {
          this._nodes.push(node);
          lastNode = null;
          lastOutput = null;
          if (op.output.length === 1) {
            lastNode = node;
            lastOutput = blobName(op.output[0]);
          }
        }
      }
      // With more than one external input, those marked as initializers are not graph inputs.
      this._inputs = [];
      for (const input of netDef.external_input) {
        if (netDef.external_input.length > 1 && initializers.has(input)) {
          continue;
        }
        this._inputs.push(new caffe2.Argument(input, [ arg(input) ]));
      }
      this._outputs = [];
      for (const output of netDef.external_output) {
        this._outputs.push(new caffe2.Argument(output, [ arg(output) ]));
      }
    }
    get name() {
      return this._name;
    }
    get type() {
      return this._type;
    }
    get inputs() {
      return this._inputs;
    }
    get outputs() {
      return this._outputs;
    }
    get nodes() {
      return this._nodes;
    }
  };
  caffe2.Argument = class {
    /**
     * A named slot (graph or node input/output) connected to a list of values.
     * @param {string} name - slot name.
     * @param {caffe2.Value[]} value - values attached to the slot.
     */
    constructor(name, value) {
      Object.assign(this, { _name: name, _value: value });
    }

    /** Slot name. */
    get name() { return this._name; }

    /** Values attached to the slot. */
    get value() { return this._value; }
  };
  caffe2.Value = class {
    /**
     * A named value flowing through the graph, optionally backed by initializer data.
     * @param {string} name - value identifier; anything else is rejected.
     * @param {object|null} type - explicit type, used when there is no initializer.
     * @param {caffe2.Tensor|null} initializer - tensor providing type, quantization and data.
     * @throws {caffe2.Error} when `name` is not a string.
     */
    constructor(name, type, initializer) {
      const isString = typeof name === 'string';
      if (!isString) {
        throw new caffe2.Error("Invalid value identifier '" + JSON.stringify(name) + "'.");
      }
      this._name = name;
      this._type = type ? type : null;
      this._initializer = initializer ? initializer : null;
    }

    get name() {
      return this._name;
    }

    /** The initializer's type takes precedence over the explicit one. */
    get type() {
      return this._initializer ? this._initializer.type : this._type;
    }

    /** Quantization description from the initializer, or null without one. */
    get quantization() {
      return this._initializer ? this._initializer.quantization : null;
    }

    get initializer() {
      return this._initializer;
    }
  };
  caffe2.Node = class {
    /**
     * A graph node built from one Caffe2 OperatorDef.
     * @param {object} metadata - operator metadata provider (`type()`, `attribute()`).
     * @param {object} op - decoded OperatorDef whose `input`/`output` hold (possibly renamed) value ids.
     * @param {function(string): caffe2.Value} arg - graph-level lookup returning the shared Value for an id.
     */
    constructor(metadata, op, arg) {
      this._name = op.name || '';
      this._device = op.engine || '';
      this._metadata = metadata;
      this._chain = [];
      this._type = metadata.type(op.type);
      // The input/output code below tolerates a missing type entry; do the same here instead of
      // dereferencing `this._type.name` unconditionally.
      const typeName = this._type ? this._type.name : op.type;
      this._attributes = op.arg.map((argument) => new caffe2.Attribute(metadata, typeName, argument));
      const inputs = op.input;
      const outputs = op.output;
      this._inputs = [];
      let inputIndex = 0;
      if (this._type && this._type.inputs) {
        for (const inputDef of this._type.inputs) {
          // Optional parameters are emitted only while ids remain to be assigned.
          if (inputIndex < inputs.length || inputDef.option !== 'optional') {
            // A variadic parameter takes every remaining id.
            const inputCount = (inputDef.option === 'variadic') ? (inputs.length - inputIndex) : 1;
            const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount)
              .filter((id) => id !== '' || inputDef.option !== 'optional')
              .map((id) => arg(id));
            this._inputs.push(new caffe2.Argument(inputDef.name, inputArguments));
            inputIndex += inputCount;
          }
        }
      } else {
        // No metadata: name inputs 'input', '1', '2', ...
        this._inputs.push(...inputs.slice(inputIndex).map((input, index) => {
          const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
          return new caffe2.Argument(inputName, [ arg(input) ]);
        }));
      }
      this._outputs = [];
      let outputIndex = 0;
      if (this._type && this._type.outputs) {
        for (const outputDef of this._type.outputs) {
          if (outputIndex < outputs.length || outputDef.option !== 'optional') {
            const outputCount = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
            const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => arg(id));
            this._outputs.push(new caffe2.Argument(outputDef.name, outputArguments));
            outputIndex += outputCount;
          }
        }
      } else {
        // No metadata: name outputs 'output', '1', '2', ...
        this._outputs.push(...outputs.slice(outputIndex).map((output, index) => {
          const outputName = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
          return new caffe2.Argument(outputName, [ arg(output) ]);
        }));
      }
    }
    get name() {
      return this._name || '';
    }
    get device() {
      return this._device || '';
    }
    get type() {
      return this._type;
    }
    get inputs() {
      return this._inputs;
    }
    get outputs() {
      return this._outputs;
    }
    get attributes() {
      return this._attributes;
    }
    /** Nodes folded into this one by the graph builder (see caffe2.Graph). */
    get chain() {
      return this._chain;
    }
  };
  caffe2.Attribute = class {
    /**
     * A node attribute built from one Caffe2 Argument.
     * @param {object} metadata - operator metadata provider (`attribute(type, name)`).
     * @param {string} type - operator type name used for the metadata lookup.
     * @param {object} arg - decoded Argument; reads `floats`, `ints`, `nets`, `n` and `i`.
     */
    constructor(metadata, type, arg) {
      this._name = arg.name;
      if (arg.floats && arg.floats.length > 0) {
        this._value = arg.floats;
      } else if (arg.ints && arg.ints.length > 0) {
        this._value = arg.ints;
      } else if (arg.nets && arg.nets.length > 0) {
        this._value = arg.nets.map((net) => new caffe2.Graph(metadata, net, null));
        this._type = 'graph[]';
      } else if (arg.n) {
        this._value = new caffe2.Graph(metadata, arg.n, null);
        this._type = 'graph';
      } else {
        // Fallback: the integer scalar.
        this._value = arg.i;
      }
      const schema = metadata.attribute(type, arg.name);
      if (schema) {
        if (Object.prototype.hasOwnProperty.call(schema, 'type')) {
          this._type = schema.type;
          if (this._type === 'boolean') {
            this._value = this._value !== 0 && this._value.toString() !== '0';
          }
        }
        if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
          this._visible = false;
        } else if (schema.default !== undefined) {
          const defaultValue = schema.default;
          // Loose equality and string comparison on purpose: values may be numbers, arrays or
          // other objects that stringify like the metadata default. A null default has no
          // toString, so it only matches via `==`.
          if (this._value == defaultValue ||
            (this._value && defaultValue !== null && this._value.toString() == defaultValue.toString())) {
            this._visible = false;
          }
        }
      }
    }
    get name() {
      return this._name;
    }
    get type() {
      return this._type || null;
    }
    get value() {
      return this._value;
    }
    get visible() {
      return this._visible == false ? false : true;
    }
  };
  caffe2.Tensor = class {
    /**
     * Initializer tensor assembled from an init op's arguments.
     * @param {string} name - blob name.
     * @param {object} tensor - argument table; reads `shape`, `dataType`, `values`, `Y_scale`, `Y_zero_point`.
     */
    constructor(name, tensor) {
      this._name = name;
      const hasShape = tensor.shape && tensor.shape.ints;
      const dims = hasShape ? tensor.shape.ints : null;
      this._type = new caffe2.TensorType(tensor.dataType, new caffe2.TensorShape(dims));
      this._values = tensor.values || null;
      this._scale = tensor.Y_scale ? tensor.Y_scale.f : 0;
      this._zeroPoint = tensor.Y_zero_point ? tensor.Y_zero_point.i : 0;
    }

    get name() {
      return this._name;
    }

    get type() {
      return this._type;
    }

    get category() {
      return 'Initializer';
    }

    /** Text such as "0.5 * (q - 3)", or null when both scale and zero point are 0. */
    get quantization() {
      if (this._scale == 0 && this._zeroPoint == 0) {
        return null;
      }
      const offsetText = this._zeroPoint == 0 ? 'q' : '(q - ' + this._zeroPoint.toString() + ')';
      return this._scale.toString() + ' * ' + offsetText;
    }

    get layout() {
      return '|';
    }

    /** Raw values for float32, boolean, int8 and int32 tensors; null for anything else. */
    get values() {
      const source = this._values;
      if (source) {
        const dataType = this._type.dataType;
        if (dataType === 'float32') {
          return source.floats;
        }
        if (dataType === 'boolean' || dataType === 'int32') {
          return source.ints;
        }
        if (dataType === 'int8') {
          return new Int8Array(source.s);
        }
      }
      return null;
    }
  };
  caffe2.TensorType = class {
    /**
     * Element type plus shape of a tensor.
     * @param {string|null} dataType - element type name; reported as '?' when missing.
     * @param {caffe2.TensorShape} shape - tensor shape.
     */
    constructor(dataType, shape) {
      this._dataType = dataType;
      this._shape = shape;
    }

    get dataType() {
      return this._dataType ? this._dataType : '?';
    }

    get shape() {
      return this._shape;
    }

    /** e.g. "float32[1,3]". */
    toString() {
      const typeText = this.dataType;
      const shapeText = this._shape.toString();
      return typeText + shapeText;
    }
  };
  caffe2.TensorShape = class {
    /**
     * @param {Array|null} dimensions - dimension sizes, or null when unknown.
     */
    constructor(dimensions) {
      this._dimensions = dimensions;
    }

    get dimensions() {
      return this._dimensions;
    }

    /** "[d0,d1,...]", or an empty string when the dimensions are unknown. */
    toString() {
      const dims = this._dimensions;
      if (!dims) {
        return '';
      }
      const joined = dims.map((size) => size.toString()).join(',');
      return '[' + joined + ']';
    }
  };
  // Error type thrown by this reader; the fixed `name` labels the failure as a Caffe2 loading error.
  caffe2.Error = class extends Error {
    constructor(message) {
      super(message);
      this.name = 'Error loading Caffe2 model.';
    }
  };
  // Expose the factory via CommonJS when a `module.exports` object exists; it is the only export
  // of this file.
  if (typeof module !== 'undefined' && typeof module.exports === 'object') {
    module.exports.ModelFactory = caffe2.ModelFactory;
  }