caffe2.js 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  const caffe2 = {};

  // Detects Caffe2 NetDef files (binary `.pb`, text `.pbtxt`/`.prototxt`) and opens them,
  // pairing a predict net with its init net when a sibling file can be fetched by name.
  caffe2.ModelFactory = class {

      // Returns the result of `context.set('caffe2.pbtxt' | 'caffe2.pb')` when the content looks
      // like a Caffe2 NetDef, otherwise `null`.
      async match(context) {
          const identifier = context.identifier.toLowerCase();
          // `identifier` is already lower-cased, so the extension needs no further normalization.
          const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop() : '';
          switch (extension) {
              case 'pbtxt':
              case 'prototxt': {
                  const tags = await context.tags('pbtxt');
                  // Require an `op` field, but reject files whose ops carry fields this loader does
                  // not expect (presumably other text-protobuf formats that also use `op`).
                  if (tags.has('op') && !tags.has('op.attr') && !tags.has('op.graph_op_name') && !tags.has('op.endpoint')) {
                      return context.set('caffe2.pbtxt');
                  }
                  break;
              }
              case 'pb': {
                  const tags = await context.tags('pb');
                  if (tags.size > 0 &&
                      Array.from(tags.keys()).every((tag) => tag <= 9) &&
                      Array.from(tags.values()).every((type) => type <= 4)) {
                      // A lone field 2 in a file named saved_model.pb is left to another loader
                      // (presumably TensorFlow SavedModel).
                      if (tags.size === 1 && tags.get(2) === 2 && identifier.endsWith('saved_model.pb')) {
                          return null;
                      }
                      // Expected protobuf wire type per top-level field number
                      // (0 = varint, 2 = length-delimited).
                      const schema = [[1, 2], [2, 2], [3, 2], [4, 0], [5, 2], [6, 2], [7, 2], [8, 2], [9, 2]];
                      if (schema.every(([key, value]) => !tags.has(key) || tags.get(key) === value)) {
                          const stream = context.stream;
                          if (stream.length > 3) {
                              const buffer = stream.peek(Math.min(stream.length, 67));
                              const [signature, size] = buffer;
                              switch (signature) {
                                  case 0x0A:
                                      // Field 1 (length-delimited) first: accept a short string of
                                      // bytes in 32..127 immediately followed by a field 2 tag (0x12).
                                      if (size < 64 &&
                                          buffer.length > 2 + size + 1 &&
                                          buffer.slice(2, 2 + size).every((c) => c >= 32 && c <= 127) &&
                                          buffer[2 + size] === 0x12) {
                                          return context.set('caffe2.pb');
                                      }
                                      break;
                                  case 0x12:
                                      // Field 2 (length-delimited) first.
                                      return context.set('caffe2.pb');
                                  default:
                                      break;
                              }
                          }
                      }
                  }
                  break;
              }
              default: {
                  break;
              }
          }
          return null;
      }

      // Decodes the predict and/or init NetDef and builds a `caffe2.Model`.
      // Siblings are located by file name: `*init_net*` <-> `*predict_net*`, and
      // `<base>_init.<ext>` <-> `<base>.<ext>`. Failing to fetch a sibling is not an error.
      // Throws `caffe2.Error` for an unsupported `context.type`, when neither net is available,
      // or when a file cannot be decoded as a NetDef.
      async open(context) {
          caffe2.proto = await context.require('./caffe2-proto');
          caffe2.proto = caffe2.proto.caffe2;
          const metadata = await context.metadata('caffe2-metadata.json');
          const identifier = context.identifier;
          const parts = identifier.split('.');
          const extension = parts.pop().toLowerCase();
          const base = parts.join('.');
          const lowerBase = base.toLowerCase();
          let predict = null;
          let init = null;
          switch (context.type) {
              case 'caffe2.pbtxt': {
                  if (lowerBase.endsWith('init_net') || lowerBase.startsWith('init_net')) {
                      init = context;
                      try {
                          const name = identifier.replace('init_net', 'predict_net');
                          predict = await context.fetch(name);
                          predict.set(context.type);
                      } catch {
                          // continue regardless of error
                      }
                  } else if (lowerBase.endsWith('predict_net') || lowerBase.startsWith('predict_net')) {
                      predict = context;
                      const name = identifier.replace('predict_net', 'init_net');
                      try {
                          // Prefer a binary init net, then fall back to a text one.
                          init = await context.fetch(name.replace(/\.pbtxt/, '.pb'));
                          init.set('caffe2.pb');
                      } catch {
                          try {
                              init = await context.fetch(name);
                              init.set('caffe2.pbtxt');
                          } catch {
                              // continue regardless of error
                          }
                      }
                  } else {
                      predict = context;
                      try {
                          init = await context.fetch(`${base}_init.pb`);
                          init.set('caffe2.pb');
                      } catch {
                          // continue regardless of error
                      }
                  }
                  break;
              }
              case 'caffe2.pb': {
                  if (lowerBase.endsWith('init_net')) {
                      init = context;
                      const candidates = new Set([extension, 'pb', 'pbtxt']);
                      for (const candidate of candidates) {
                          try {
                              // Case-insensitive to agree with the lower-cased suffix check above.
                              const name = `${base.replace(/init_net$/i, '')}predict_net.${candidate}`;
                              /* eslint-disable no-await-in-loop */
                              predict = await context.fetch(name);
                              /* eslint-enable no-await-in-loop */
                              predict.set(`caffe2.${candidate}`);
                              break;
                          } catch {
                              // continue regardless of error
                          }
                      }
                  } else if (lowerBase.endsWith('_init')) {
                      // This file is the init net; its predict net has the `_init` suffix removed.
                      init = context;
                      try {
                          const name = `${base.replace(/_init$/i, '')}.${extension}`;
                          predict = await context.fetch(name);
                          predict.set(context.type);
                      } catch {
                          // continue regardless of error
                      }
                  } else if (lowerBase.endsWith('predict_net') || lowerBase.startsWith('predict_net')) {
                      predict = context;
                      try {
                          const name = identifier.replace('predict_net', 'init_net');
                          init = await context.fetch(name);
                          init.set(context.type);
                      } catch {
                          // continue regardless of error
                      }
                  } else {
                      predict = context;
                      try {
                          const file = `${base}_init.${extension}`;
                          init = await context.fetch(file, null);
                          init.set(context.type);
                      } catch {
                          // continue regardless of error
                      }
                  }
                  break;
              }
              default: {
                  throw new caffe2.Error(`Unsupported Caffe2 format '${context.type}'.`);
              }
          }
          if (!predict && !init) {
              throw new caffe2.Error(`Caffe2 model does not contain predict or init data.`);
          }
          // Decodes one located file into a NetDef; `null` when the file was not found.
          const decode = async (target) => {
              if (!target) {
                  return null;
              }
              switch (target.type) {
                  case 'caffe2.pb': {
                      try {
                          const reader = await target.read('protobuf.binary');
                          return caffe2.proto.NetDef.decode(reader);
                      } catch (error) {
                          const message = error && error.message ? error.message : error.toString();
                          throw new caffe2.Error(`File format is not caffe2.NetDef (${message.replace(/\.$/, '')}).`);
                      }
                  }
                  case 'caffe2.pbtxt': {
                      try {
                          const reader = await target.read('protobuf.text');
                          // Unknown fields inside a DeviceOption are stored instead of rejected.
                          reader.field = function(tag, message) {
                              if (message instanceof caffe2.proto.DeviceOption) {
                                  message[tag] = this.read();
                                  return;
                              }
                              throw new Error(`Unknown field '${tag}' ${this.location()}`);
                          };
                          return caffe2.proto.NetDef.decodeText(reader);
                      } catch (error) {
                          const message = error && error.message ? error.message : error.toString();
                          throw new caffe2.Error(`File format is not caffe2.NetDef (${message.replace(/\.$/, '')}).`);
                      }
                  }
                  default: {
                      throw new caffe2.Error(`Unsupported Caffe2 predict format '${target.type}'.`);
                  }
              }
          };
          const predict_net = await decode(predict);
          const init_net = await decode(init);
          return new caffe2.Model(metadata, predict_net, init_net);
      }
  };
  // Model wrapper: format label, the net's domain (or null), and one graph built from the
  // predict net, or from the init net when there is no predict net.
  caffe2.Model = class {

      constructor(metadata, predict_net, init_net) {
          const source = predict_net ? predict_net : init_net;
          this.format = 'Caffe2';
          this.domain = source.domain ? source.domain : null;
          const graph = new caffe2.Graph(metadata, predict_net, init_net);
          this.modules = [graph];
      }
  };
  // Graph view of a Caffe2 NetDef. Nodes come from `predict_net` (or `init_net` when no predict
  // net is given). When both are given, the single-output ops of `init_net` supply initializer
  // tensors for the predict net's values.
  caffe2.Graph = class {

      constructor(metadata, predict_net, init_net) {
          const net = predict_net || init_net;
          // An init net shown on its own is the graph itself, not a source of initializers.
          const initNet = predict_net ? init_net : null;
          this.name = net.name || '';
          this.nodes = [];
          this.description = net.type;
          const initializers = new Set();
          const tensors = new Map();
          for (const name of net.external_input) {
              tensors.set(name, new caffe2.Tensor(name, {}));
          }
          if (initNet) {
              // Fill op type -> tensor data type; `null` leaves the data type unknown.
              const dataTypes = new Map([
                  ['GivenTensorFill', 'float32'],
                  ['GivenTensorDoubleFill', 'float64'],
                  ['GivenTensorBoolFill', 'boolean'],
                  ['GivenTensorByteStringToUInt8Fill', 'uint8'],
                  ['GivenTensorInt16Fill', 'int16'],
                  ['GivenTensorSInt16Fill', 'int16'],
                  ['GivenTensorIntFill', 'int32'],
                  ['GivenTensorInt64Fill', 'int64'],
                  ['GivenTensorStringFill', 'string'],
                  ['Int8GivenIntTensorFill', 'int32'],
                  ['Int8GivenTensorFill', 'int8'],
                  ['XavierFill', null],
                  ['ConstantFill', null]
              ]);
              for (const op of initNet.op) {
                  if (op.output && op.output.length === 1) {
                      const [name] = op.output;
                      // Index the fill op's arguments by name (read later as `shape`, `values`, ...).
                      const tensor = {};
                      for (const arg of op.arg) {
                          tensor[arg.name] = arg;
                      }
                      if (!dataTypes.has(op.type)) {
                          throw new caffe2.Error(`Unsupported init op '${op.type}'.`);
                      }
                      tensor.dataType = dataTypes.get(op.type);
                      // Float values other than a single 0 mark the tensor as an initializer here.
                      if (tensor.values && tensor.values.floats && (tensor.values.floats.length !== 1 || tensor.values.floats[0] !== 0)) {
                          initializers.add(name);
                      }
                      tensors.set(name, new caffe2.Tensor(name, tensor));
                  }
              }
          }
          // Give each redefinition of an already-produced name a unique id (`name\n<op index>`)
          // and point later inputs at the most recent definition. This rewrites `op.input` and
          // `op.output` in place. A Map avoids matching inherited Object members for blob names
          // such as `constructor`; truthiness checks keep the original handling of ''.
          const scope = new Map();
          for (let i = 0; i < net.op.length; i++) {
              const op = net.op[i];
              op.input = op.input.map((input) => {
                  const current = scope.get(input);
                  return current ? current : input;
              });
              op.output = op.output.map((output) => {
                  if (scope.get(output)) {
                      const next = `${output}\n${i}`; // custom argument id
                      scope.set(output, next);
                      return next;
                  }
                  scope.set(output, output);
                  return output;
              });
          }
          // Values by name. `values.map` creates a value on first use; passing a type or tensor
          // for a name that already exists is an error.
          const values = new Map();
          values.map = (name, type, tensor) => {
              if (!values.has(name)) {
                  values.set(name, new caffe2.Value(name, type || null, tensor || null));
              } else if (type || tensor) {
                  throw new caffe2.Error(`Duplicate value '${name}'.`);
              }
              return values.get(name);
          };
          // Any non-first op input naming a known tensor becomes an initializer-backed value.
          for (const op of net.op) {
              for (const [index, name] of op.input.entries()) {
                  if (index > 0 && tensors.has(name)) {
                      if (!values.has(name)) {
                          values.set(name, new caffe2.Value(name, null, tensors.get(name)));
                      }
                      initializers.add(name);
                  }
              }
          }
          // Op outputs that name a known tensor are treated as initializers as well.
          for (const op of net.op) {
              for (const name of op.output) {
                  if (tensors.has(name)) {
                      initializers.add(name);
                  }
              }
          }
          // Ops whose single input and first output share a base name, and which consume the
          // previous node's single output, are appended to that node's `chain`.
          let lastNode = null;
          let lastOutput = null;
          for (const op of net.op) {
              const node = new caffe2.Node(metadata, op, values);
              if (op.input.length === 1 &&
                  op.output.length >= 1 &&
                  op.input[0].split('\n').shift() === op.output[0].split('\n').shift() &&
                  lastNode &&
                  lastOutput === op.input[0].split('\n').shift()) {
                  lastNode.chain.push(node);
              } else {
                  this.nodes.push(node);
                  lastNode = null;
                  lastOutput = null;
                  if (op.output.length === 1) {
                      lastNode = node;
                      lastOutput = op.output[0].split('\n').shift();
                  }
              }
          }
          // Graph inputs: external inputs, skipping initializers unless there is only one input.
          this.inputs = [];
          for (const input of net.external_input) {
              if (net.external_input.length > 1 && initializers.has(input)) {
                  continue;
              }
              this.inputs.push(new caffe2.Argument(input, [values.map(input)]));
          }
          this.outputs = [];
          for (const output of net.external_output) {
              this.outputs.push(new caffe2.Argument(output, [values.map(output)]));
          }
      }
  };
  // Named entry used for graph/node inputs and outputs (a list of values) and for node
  // attributes (a value with an optional type and a visibility flag).
  caffe2.Argument = class {

      constructor(name, value, type = null, visible = true) {
          Object.assign(this, { name, value, type, visible });
      }
  };
  // Named value in the graph. When no type is given, the type (and any quantization) comes
  // from the optional initializer tensor. Throws `caffe2.Error` for a non-string name.
  caffe2.Value = class {

      constructor(name, type, initializer = null) {
          const valid = typeof name === 'string';
          if (!valid) {
              throw new caffe2.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
          }
          this.name = name;
          this.type = type || !initializer ? type : initializer.type;
          let quantization = null;
          if (initializer && initializer.quantization) {
              quantization = initializer.quantization;
          }
          this.quantization = quantization;
          this.initializer = initializer;
      }
  };
  // Node view of a Caffe2 operator: attributes come from `op.arg`. Inputs and outputs are
  // grouped by the operator type's metadata `inputs`/`outputs` when present; otherwise they
  // are named positionally ('input'/'output', then '1', '2', ...).
  caffe2.Node = class {

      constructor(metadata, op, values) {
          this.name = op.name || '';
          this.device = op.engine || '';
          this.chain = [];
          this.type = metadata.type(op.type);
          this.attributes = op.arg.map((arg) => {
              const schema = metadata.attribute(op.type, arg.name);
              const name = arg.name;
              let value = null;
              let type = null;
              let visible = true;
              // The first populated representation wins: floats, ints, nets, n, otherwise i.
              if (arg.floats && arg.floats.length > 0) {
                  value = arg.floats;
              } else if (arg.ints && arg.ints.length > 0) {
                  value = arg.ints;
              } else if (arg.nets && arg.nets.length > 0) {
                  value = arg.nets.map((net) => new caffe2.Graph(metadata, net, null));
                  type = 'graph[]';
              } else if (arg.n) {
                  value = new caffe2.Graph(metadata, arg.n, null);
                  type = 'graph';
              } else {
                  value = arg.i;
              }
              if (schema) {
                  type = !type && schema.type ? schema.type : type;
                  if (type === 'boolean') {
                      value = value !== 0 && value.toString() !== '0';
                  }
                  if (schema.visible === false) {
                      visible = false;
                  } else if (schema.default !== undefined) {
                      // Hide attributes equal to the schema default: strictly first (so falsy
                      // defaults such as `false` match), then by string form.
                      if (value === schema.default || (value && value.toString() === schema.default.toString())) {
                          visible = false;
                      }
                  }
              }
              return new caffe2.Argument(name, value, type, visible);
          });
          const inputs = op.input;
          const outputs = op.output;
          this.inputs = [];
          let inputIndex = 0;
          if (this.type && this.type.inputs) {
              for (const inputDef of this.type.inputs) {
                  if (inputIndex < inputs.length || inputDef.option !== 'optional') {
                      // A variadic definition consumes all remaining ids; others consume one.
                      const inputCount = (inputDef.option === 'variadic') ? (inputs.length - inputIndex) : 1;
                      const inputArguments = inputs.slice(inputIndex, inputIndex + inputCount).filter((id) => id !== '' || inputDef.option !== 'optional').map((id) => values.map(id));
                      this.inputs.push(new caffe2.Argument(inputDef.name, inputArguments));
                      inputIndex += inputCount;
                  }
              }
          } else {
              this.inputs.push(...inputs.slice(inputIndex).map((input, index) => {
                  const inputName = ((inputIndex + index) === 0) ? 'input' : (inputIndex + index).toString();
                  return new caffe2.Argument(inputName, [values.map(input)]);
              }));
          }
          this.outputs = [];
          let outputIndex = 0;
          if (this.type && this.type.outputs) {
              for (const outputDef of this.type.outputs) {
                  if (outputIndex < outputs.length || outputDef.option !== 'optional') {
                      const outputCount = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
                      const outputArguments = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => values.map(id));
                      this.outputs.push(new caffe2.Argument(outputDef.name, outputArguments));
                      outputIndex += outputCount;
                  }
              }
          } else {
              this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
                  const outputName = ((outputIndex + index) === 0) ? 'output' : (outputIndex + index).toString();
                  return new caffe2.Argument(outputName, [values.map(output)]);
              }));
          }
      }
  };
  // Initializer tensor built from a fill op's arguments indexed by name (`shape`, `values`,
  // `Y_scale`, `Y_zero_point`) plus the `dataType` chosen by the caller.
  caffe2.Tensor = class {

      constructor(name, tensor) {
          this.name = name;
          const shape = tensor.shape && tensor.shape.ints ? tensor.shape.ints : null;
          this.type = new caffe2.TensorType(tensor.dataType, new caffe2.TensorShape(shape));
          this.values = null;
          this.category = 'Initializer';
          this.encoding = '|';
          if (tensor.Y_scale !== undefined || tensor.Y_zero_point !== undefined) {
              // The zero point may be a bigint or a plain number. Convert it with the standard
              // Number() rather than a non-standard `toNumber()` method.
              const zeroPoint = tensor.Y_zero_point ? tensor.Y_zero_point.i : undefined;
              this.quantization = {
                  type: 'linear',
                  scale: [tensor.Y_scale ? tensor.Y_scale.f : 0],
                  offset: [typeof zeroPoint === 'bigint' || typeof zeroPoint === 'number' ? Number(zeroPoint) : 0]
              };
          }
          if (tensor.values) {
              switch (this.type.dataType) {
                  case 'float32': this.values = tensor.values.floats; break;
                  case 'boolean': this.values = tensor.values.ints; break;
                  // Assumes `values.s` holds raw bytes; each is converted to a signed 8-bit value.
                  case 'int8': this.values = new Int8Array(tensor.values.s); break;
                  case 'int32': this.values = tensor.values.ints; break;
                  default: break;
              }
          }
      }
  };
  // Tensor type: data type name ('?' when unknown) plus a shape; renders as e.g. `float32[1,3]`.
  caffe2.TensorType = class {

      constructor(dataType, shape) {
          this.dataType = dataType ? dataType : '?';
          this.shape = shape;
      }

      toString() {
          const dims = this.shape.toString();
          return `${this.dataType}${dims}`;
      }
  };
  // Tensor shape. Bigint dimensions are converted to numbers with the standard Number(); the
  // conversion is lossy beyond Number.MAX_SAFE_INTEGER. Non-array input is stored unchanged.
  caffe2.TensorShape = class {

      constructor(dimensions) {
          this.dimensions = Array.isArray(dimensions) ? dimensions.map((dim) => typeof dim === 'bigint' ? Number(dim) : dim) : dimensions;
      }

      // Renders `[d0,d1,...]`, or an empty string when there are no dimensions.
      toString() {
          if (Array.isArray(this.dimensions) && this.dimensions.length > 0) {
              return `[${this.dimensions.map((dim) => dim.toString()).join(',')}]`;
          }
          return '';
      }
  };
  // Error type thrown by this loader; callers can display `name` as the error title.
  caffe2.Error = class extends Error {

      constructor(message) {
          super(message);
          Object.assign(this, { name: 'Error loading Caffe2 model.' });
      }
  };

  const { ModelFactory } = caffe2;
  export { ModelFactory };