// caffe.js 23 KB
  1. const caffe = {};
  2. caffe.ModelFactory = class {
  3. async match(context) {
  4. const identifier = context.identifier;
  5. const extension = identifier.lastIndexOf('.') > 0 ? identifier.split('.').pop().toLowerCase() : '';
  6. if (extension === 'caffemodel') {
  7. return context.set('caffe.pb');
  8. }
  9. if (identifier === 'saved_model.pbtxt' || identifier === 'saved_model.prototxt' ||
  10. identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
  11. identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
  12. return null;
  13. }
  14. const tags = await context.tags('pbtxt');
  15. if (tags.has('layer') || tags.has('layers')) {
  16. return context.set('caffe.pbtxt');
  17. } else if (tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
  18. return context.set('caffe.pbtxt.solver');
  19. }
  20. return null;
  21. }
  22. async open(context) {
  23. caffe.proto = await context.require('./caffe-proto');
  24. caffe.proto = caffe.proto.caffe;
  25. const openModel = async (context, netParameter) => {
  26. const metadata = await context.metadata('caffe-metadata.json');
  27. return new caffe.Model(metadata, netParameter);
  28. };
  29. const openNetParameterText = async (context, identifier, content) => {
  30. let netParameter = null;
  31. try {
  32. const reader = await content.read('protobuf.text');
  33. reader.field = function(tag, message) {
  34. const type = message.constructor.name;
  35. if (tag.endsWith('_param') && (type === 'LayerParameter' || type === 'V1LayerParameter' || type === 'V0LayerParameter')) {
  36. message[tag] = caffe.ModelFactory._decodeText(reader);
  37. return;
  38. } else if (message.constructor.name.endsWith('Parameter') || message.constructor.name === 'ParamSpec') {
  39. if (message[tag]) {
  40. if (!Array.isArray(message[tag])) {
  41. message[tag] = [message[tag]];
  42. }
  43. message[tag].push(this.read());
  44. } else {
  45. message[tag] = this.read();
  46. }
  47. return;
  48. }
  49. throw new Error(`Unknown field '${tag}' ${this.location()}`);
  50. };
  51. reader.enum = function(type) {
  52. const token = this.token();
  53. this.next();
  54. this.semicolon();
  55. if (!Object.prototype.hasOwnProperty.call(type, token)) {
  56. const value = Number.parseInt(token, 10);
  57. if (!Number.isNaN(token - value)) {
  58. return value;
  59. }
  60. return token;
  61. }
  62. return type[token];
  63. };
  64. if (/MobileNetSSD_train_template.prototxt/.exec(identifier)) {
  65. reader.integer = function() {
  66. const token = this.token();
  67. const value = Number.parseInt(token, 10);
  68. this.next();
  69. this.semicolon();
  70. if (Number.isNaN(token - value)) {
  71. return token;
  72. }
  73. return value;
  74. };
  75. }
  76. netParameter = caffe.proto.NetParameter.decodeText(reader);
  77. } catch (error) {
  78. const message = error && error.message ? error.message : error.toString();
  79. throw new caffe.Error(`File text format is not caffe.NetParameter (${message.replace(/\.$/, '')}).`);
  80. }
  81. return openModel(context, netParameter);
  82. };
  83. switch (context.type) {
  84. case 'caffe.pbtxt.solver': {
  85. const reader = await context.read('protobuf.text');
  86. reader.field = function(tag, message) {
  87. if (message instanceof caffe.proto.SolverParameter) {
  88. message[tag] = this.read();
  89. return;
  90. }
  91. throw new Error(`Unknown field '${tag}'${this.location()}`);
  92. };
  93. const solver = caffe.proto.SolverParameter.decodeText(reader);
  94. if (solver.net_param) {
  95. return openModel(context, solver.net_param);
  96. }
  97. let name = solver.net || solver.train_net;
  98. name = name.split('/').pop();
  99. try {
  100. const content = await context.fetch(name);
  101. return await openNetParameterText(context, name, content);
  102. } catch (error) {
  103. const message = error.message ? error.message : error.toString();
  104. throw new caffe.Error(`Failed to load '${name}' (${message.replace(/\.$/, '')}).`);
  105. }
  106. }
  107. case 'caffe.pbtxt': {
  108. return await openNetParameterText(context, context.identifier, context);
  109. }
  110. case 'caffe.pb': {
  111. let netParameter = null;
  112. try {
  113. const reader = await context.read('protobuf.binary');
  114. netParameter = caffe.proto.NetParameter.decode(reader);
  115. } catch (error) {
  116. const message = error && error.message ? error.message : error.toString();
  117. throw new caffe.Error(`File format is not caffe.NetParameter (${message.replace(/\.$/, '')}).`);
  118. }
  119. return await openModel(context, netParameter);
  120. }
  121. default: {
  122. throw new caffe.Error(`Unsupported Caffe format '${context.type}'.`);
  123. }
  124. }
  125. }
  126. static _decodeText(reader) {
  127. const message = {};
  128. reader.start();
  129. while (!reader.end()) {
  130. const tag = reader.tag();
  131. const value = reader.read();
  132. if (message[tag]) {
  133. if (!Array.isArray(message[tag])) {
  134. message[tag] = [message[tag]];
  135. }
  136. message[tag].push(value);
  137. } else {
  138. message[tag] = value;
  139. }
  140. }
  141. return message;
  142. }
  143. };
  144. caffe.Model = class {
  145. constructor(metadata, net) {
  146. this.name = net.name;
  147. this.format = 'Caffe';
  148. this.modules = [];
  149. let version = -1;
  150. if (net.layers && net.layers.length > 0) {
  151. if (net.layers.every((layer) => Object.prototype.hasOwnProperty.call(layer, 'layer'))) {
  152. version = 0;
  153. net.layer = net.layers;
  154. } else {
  155. version = 1;
  156. net.layer = net.layers;
  157. }
  158. } else if (net.layer && net.layer.length > 0) {
  159. version = 2;
  160. }
  161. this.format = `Caffe v${version}`;
  162. const phases = new Set();
  163. for (const layer of net.layer) {
  164. for (const include of layer.include) {
  165. if (include.phase !== undefined) {
  166. phases.add(include.phase);
  167. }
  168. }
  169. }
  170. if (phases.size === 0) {
  171. phases.add(-1);
  172. }
  173. for (const phase of phases) {
  174. const graph = new caffe.Graph(metadata, phase, net, version);
  175. this.modules.push(graph);
  176. }
  177. }
  178. };
  179. caffe.Graph = class {
  180. constructor(metadata, phase, net, version) {
  181. switch (phase) {
  182. case 0: this.name = 'TRAIN'; break;
  183. case 1: this.name = 'TEST'; break;
  184. case -1: this.name = ''; break;
  185. default: this.name = phase.toString(); break;
  186. }
  187. this.nodes = [];
  188. this.inputs = [];
  189. this.outputs = [];
  190. for (const layer of net.layer) {
  191. layer.input = layer.bottom.slice(0);
  192. layer.output = layer.top.slice(0);
  193. layer.chain = [];
  194. }
  195. const layers = [];
  196. for (const layer of net.layer) {
  197. if (phase === -1 || layer.include.every((include) => include.phase === phase)) {
  198. layers.push(layer);
  199. }
  200. }
  201. const scopes = new Map();
  202. for (let i = 0; i < layers.length; i++) {
  203. const layer = layers[i];
  204. layer.input = layer.input.map((input) => scopes.has(input) ? scopes.get(input) : input);
  205. layer.output = layer.output.map((output) => {
  206. const value = scopes.has(output) ? `${output}\n${i}` : output;
  207. scopes.set(output, value);
  208. return value;
  209. });
  210. }
  211. // Graph Inputs
  212. const usedOutputs = new Set();
  213. for (const layer of layers) {
  214. for (const output of layer.output) {
  215. usedOutputs.add(output);
  216. }
  217. }
  218. const unusedInputs = [];
  219. for (const layer of layers) {
  220. for (const input of layer.input) {
  221. if (!usedOutputs.has(input)) {
  222. unusedInputs.push(input);
  223. }
  224. }
  225. }
  226. const values = new Map();
  227. const value = (name, type) => {
  228. if (!values.has(name)) {
  229. values.set(name, new caffe.Value(name, type));
  230. } else if (type) {
  231. throw new caffe.Error(`Duplicate value '${name}'.`);
  232. }
  233. return values.get(name);
  234. };
  235. const nodes = [];
  236. let lastLayer = null;
  237. let lastTop = null;
  238. while (layers.length > 0) {
  239. let layer = layers.shift();
  240. if (layer.output.length === 1 && layer.input.length === 1 &&
  241. layer.output[0].split('\n').shift() === layer.input[0].split('\n').shift() &&
  242. lastLayer &&
  243. lastTop === layer.output[0].split('\n').shift()) {
  244. lastLayer.chain = lastLayer.chain || [];
  245. lastLayer.chain.push(layer);
  246. } else {
  247. if (layer.type === 'Input' && layer.input.length === 0) {
  248. for (let i = 0; i < layer.output.length; i++) {
  249. const output = layer.output[i];
  250. const dim = layer.input_param && layer.input_param.shape && i < layer.input_param.shape.length ? layer.input_param.shape[i].dim : null;
  251. const shape = dim ? new caffe.TensorShape(dim.map((dim) => dim.toNumber())) : null;
  252. const type = shape ? new caffe.TensorType(null, shape) : null;
  253. const argument = new caffe.Argument(output, [value(output, type)]);
  254. this.inputs.push(argument);
  255. }
  256. layer = null;
  257. }
  258. if (layer) {
  259. nodes.push(layer);
  260. lastLayer = null;
  261. lastTop = null;
  262. if (layer.output.length === 1) {
  263. lastLayer = layer;
  264. lastTop = layer.output[0].split('\n').shift();
  265. }
  266. }
  267. }
  268. }
  269. if (net.input) {
  270. for (let i = 0; i < net.input.length; i++) {
  271. const input = net.input[i];
  272. if (this.inputs.some((item) => item.name === input)) {
  273. continue;
  274. }
  275. let inputType = null;
  276. if (net.input_shape && i < net.input_shape.length) {
  277. const blobShape = net.input_shape[i];
  278. if (blobShape && blobShape.dim) {
  279. const shape = new caffe.TensorShape(blobShape.dim.map((dim) => dim.toNumber()));
  280. inputType = new caffe.TensorType(null, shape);
  281. }
  282. }
  283. const dim = i * 4;
  284. if (!inputType && net.input_dim && net.input_dim.length >= dim) {
  285. const shape = new caffe.TensorShape(net.input_dim.slice(dim, dim + 4));
  286. inputType = new caffe.TensorType(null, shape);
  287. }
  288. this.inputs.push(new caffe.Argument(input, [value(input, inputType, null)]));
  289. }
  290. }
  291. for (const layer of nodes) {
  292. const node = new caffe.Node(metadata, layer, version, value);
  293. if (layer.chain && layer.chain.length > 0) {
  294. for (const chain of layer.chain) {
  295. node.chain.push(new caffe.Node(metadata, chain, version, value));
  296. }
  297. }
  298. this.nodes.push(node);
  299. }
  300. if (this.inputs.length === 0 && unusedInputs.length === 1) {
  301. this.inputs.push(new caffe.Argument(unusedInputs[0], [value(unusedInputs[0], null)]));
  302. }
  303. }
  304. };
  305. caffe.Argument = class {
  306. constructor(name, value, type = null, visible = true) {
  307. this.name = name;
  308. this.value = value;
  309. this.type = type;
  310. this.visible = visible;
  311. }
  312. };
  313. caffe.Value = class {
  314. constructor(name, type = null, initializer = null) {
  315. if (typeof name !== 'string') {
  316. throw new caffe.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  317. }
  318. this.name = name;
  319. this.type = type;
  320. this.initializer = initializer;
  321. }
  322. };
  323. caffe.Node = class {
  324. constructor(metadata, layer, version, value) {
  325. this.attributes = [];
  326. this.chain = [];
  327. let type = '';
  328. switch (version) {
  329. case 0: {
  330. this.name = layer.layer.name;
  331. type = layer.layer.type;
  332. break;
  333. }
  334. case 1: {
  335. this.name = layer.name;
  336. type = caffe.Utility.layerType(layer.type);
  337. break;
  338. }
  339. case 2: {
  340. this.name = layer.name;
  341. type = layer.type;
  342. break;
  343. }
  344. default: {
  345. throw new caffe.Error(`Unsupported Caffe version '${version}'.`);
  346. }
  347. }
  348. this.type = metadata.type(type) || { name: type };
  349. let initializers = [];
  350. const attributes = [];
  351. switch (version) {
  352. case 0: {
  353. for (const name of Object.keys(layer.layer)) {
  354. if (name !== 'type' && name !== 'name' && name !== 'blobs' && name !== 'blobs_lr') {
  355. const value = layer.layer[name];
  356. const schema = metadata.attribute(type, name);
  357. attributes.push([schema, name, value]);
  358. }
  359. }
  360. initializers = layer.layer.blobs.map((blob) => new caffe.Tensor(blob));
  361. break;
  362. }
  363. case 1:
  364. case 2: {
  365. for (const layer_kind of Object.keys(layer)) {
  366. if (layer_kind.endsWith('_param') || layer_kind === 'transform_param') {
  367. const param = layer[layer_kind];
  368. if (type === 'Deconvolution') {
  369. type = 'Convolution';
  370. }
  371. const prototype = Object.getPrototypeOf(param);
  372. for (const name of Object.keys(param)) {
  373. const defaultValue = prototype[name];
  374. const value = param[name];
  375. const schema = metadata.attribute(type, name);
  376. attributes.push([schema, name, value, defaultValue]);
  377. }
  378. }
  379. }
  380. if (layer.include && layer.include.length > 0) {
  381. const schema = metadata.attribute(type, 'include');
  382. attributes.push([schema, 'include', layer.include]);
  383. }
  384. if (layer.exclude && layer.exclude.length > 0) {
  385. const schema = metadata.attribute(type, 'exclude');
  386. attributes.push([schema, 'exclude', layer.exclude]);
  387. }
  388. if (this.type === 'Data' && layer.input_param && layer.input_param.shape) {
  389. const schema = metadata.attribute(type, 'shape');
  390. attributes.push([schema, 'shape', layer.input_param.shape]);
  391. }
  392. initializers = layer.blobs.map((blob) => new caffe.Tensor(blob));
  393. break;
  394. }
  395. default: {
  396. throw new caffe.Error(`Unsupported Caffe version '${version}'.`);
  397. }
  398. }
  399. this.inputs = [];
  400. const inputs = layer.input.concat(initializers);
  401. let inputIndex = 0;
  402. if (this.type && this.type.inputs) {
  403. for (const inputDef of this.type.inputs) {
  404. if (inputIndex < inputs.length || inputDef.option !== 'optional') {
  405. const count = inputDef.option === 'variadic' ? inputs.length - inputIndex : 1;
  406. const values = inputs.slice(inputIndex, inputIndex + count).filter((input) => input !== '' || inputDef.option !== 'optional').map((input) => {
  407. return input instanceof caffe.Tensor ? new caffe.Value('', input.type, input) : value(input, null, null);
  408. });
  409. const argument = new caffe.Argument(inputDef.name, values);
  410. this.inputs.push(argument);
  411. inputIndex += count;
  412. }
  413. }
  414. }
  415. this.inputs.push(...inputs.slice(inputIndex).map((input) => {
  416. return new caffe.Argument(inputIndex.toString(), [
  417. input instanceof caffe.Tensor ? new caffe.Value('', input.type, input) : value(input, null, null)
  418. ]);
  419. }));
  420. this.outputs = [];
  421. const outputs = layer.output;
  422. let outputIndex = 0;
  423. if (this.type && this.type.outputs) {
  424. for (const outputDef of this.type.outputs) {
  425. if (outputIndex < outputs.length) {
  426. const count = (outputDef.option === 'variadic') ? (outputs.length - outputIndex) : 1;
  427. const values = outputs.slice(outputIndex, outputIndex + count).map((output) => value(output, null, null));
  428. const argument = new caffe.Argument(outputDef.name, values);
  429. this.outputs.push(argument);
  430. outputIndex += count;
  431. }
  432. }
  433. }
  434. this.outputs.push(...outputs.slice(outputIndex).map((output, index) => {
  435. return new caffe.Argument((outputIndex + index).toString(), [value(output, null, null)]);
  436. }));
  437. this.attributes = attributes.map(([metadata, name, value, defaultValue]) => {
  438. let visible = true;
  439. let type = null;
  440. if (metadata && metadata.type) {
  441. type = metadata.type;
  442. }
  443. if (value instanceof caffe.proto.BlobShape) {
  444. value = new caffe.TensorShape(value.dim.map((dim) => dim.toNumber()));
  445. type = 'shape';
  446. }
  447. if (metadata && metadata.visible === false) {
  448. visible = false;
  449. }
  450. if (metadata && metadata.default !== undefined) {
  451. defaultValue = metadata.default;
  452. }
  453. if (defaultValue !== undefined) {
  454. if (value === defaultValue) {
  455. visible = false;
  456. } else if (Array.isArray(value) && Array.isArray(defaultValue)) {
  457. if (value.length === defaultValue.length && value.every((item, index) => item === defaultValue[index])) {
  458. visible = false;
  459. }
  460. }
  461. }
  462. value = type ? caffe.Utility.enum(type, value) : value;
  463. return new caffe.Argument(name, value, type, visible);
  464. });
  465. }
  466. };
  467. caffe.Tensor = class {
  468. constructor(blob) {
  469. let shape = [];
  470. if (Object.prototype.hasOwnProperty.call(blob, 'num') &&
  471. Object.prototype.hasOwnProperty.call(blob, 'channels') &&
  472. Object.prototype.hasOwnProperty.call(blob, 'width') &&
  473. Object.prototype.hasOwnProperty.call(blob, 'height')) {
  474. if (blob.num !== 1) {
  475. shape.push(blob.num);
  476. }
  477. if (blob.channels !== 1) {
  478. shape.push(blob.channels);
  479. }
  480. if (blob.height !== 1) {
  481. shape.push(blob.height);
  482. }
  483. if (blob.width !== 1) {
  484. shape.push(blob.width);
  485. }
  486. } else if (Object.prototype.hasOwnProperty.call(blob, 'shape')) {
  487. shape = blob.shape.dim.map((dim) => Number(dim));
  488. }
  489. let dataType = '?';
  490. if (blob.data.length > 0) {
  491. dataType = 'float32';
  492. this.values = blob.data;
  493. } else if (blob.double_data.length > 0) {
  494. dataType = 'float64';
  495. this.values = blob.double_data;
  496. }
  497. this.category = 'Blob';
  498. this.encoding = '|';
  499. this.type = new caffe.TensorType(dataType, new caffe.TensorShape(shape));
  500. }
  501. };
  502. caffe.TensorType = class {
  503. constructor(dataType, shape) {
  504. this.dataType = dataType;
  505. this.shape = shape;
  506. }
  507. toString() {
  508. return (this.dataType || '?') + this.shape.toString();
  509. }
  510. };
  511. caffe.TensorShape = class {
  512. constructor(dimensions) {
  513. this.dimensions = dimensions;
  514. }
  515. toString() {
  516. return this.dimensions ? (`[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`) : '';
  517. }
  518. };
  519. caffe.Utility = class {
  520. static layerType(type) {
  521. type = type || 0;
  522. if (!caffe.Utility._layerTypeMap) {
  523. caffe.Utility._layerTypeMap = new Map();
  524. const known = { 'BNLL': 'BNLL', 'HDF5': 'HDF5', 'LRN': 'LRN', 'RELU': 'ReLU', 'TANH': 'TanH', 'ARGMAX': 'ArgMax', 'MVN': 'MVN', 'ABSVAL': 'AbsVal' };
  525. for (const key of Object.keys(caffe.proto.V1LayerParameter.LayerType)) {
  526. const value = caffe.proto.V1LayerParameter.LayerType[key];
  527. caffe.Utility._layerTypeMap.set(value, key.split('_').map((item) => known[item] || item.substring(0, 1) + item.substring(1).toLowerCase()).join(''));
  528. }
  529. }
  530. return caffe.Utility._layerTypeMap.has(type) ? caffe.Utility._layerTypeMap.get(type) : type.toString();
  531. }
  532. static enum(name, value) {
  533. let type = caffe.proto;
  534. const parts = name.split('.');
  535. while (type && parts.length > 0) {
  536. type = type[parts.shift()];
  537. }
  538. if (type) {
  539. caffe.Utility._enumKeyMap = caffe.Utility._enumKeyMap || new Map();
  540. if (!caffe.Utility._enumKeyMap.has(name)) {
  541. const map = new Map(Object.entries(type).map(([name, value]) => [value, name]));
  542. caffe.Utility._enumKeyMap.set(name, map);
  543. }
  544. const map = caffe.Utility._enumKeyMap.get(name);
  545. if (map.has(value)) {
  546. return map.get(value);
  547. }
  548. }
  549. return value;
  550. }
  551. };
  552. caffe.Error = class extends Error {
  553. constructor(message) {
  554. super(message);
  555. this.name = 'Error loading Caffe model.';
  556. }
  557. };
  558. export const ModelFactory = caffe.ModelFactory;