barracuda.js 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  // Experimental
  const barracuda = {};

  /**
   * Detects and opens Unity Barracuda (.nn) model files.
   */
  barracuda.ModelFactory = class {

      /**
       * Cheap signature check on the first bytes of the stream: byte 0 must be
       * at most 0x20 and bytes 1..7 must all be zero.
       * @param {object} context - host context providing `stream` and `set(name)`.
       * @returns {Promise<*>} the result of `context.set('barracuda')` on a match, otherwise null.
       */
      async match(context) {
          const { stream } = context;
          if (!(stream && stream.length > 12)) {
              return null;
          }
          const header = stream.peek(12);
          const isSignature = header[0] <= 0x20 && header.subarray(1, 8).every((byte) => byte === 0x00);
          return isSignature ? context.set('barracuda') : null;
      }

      /**
       * Parses the binary model and pairs it with the shared operator registry.
       * @param {object} context - host context whose `read('binary')` resolves to the
       *     reader wrapped by barracuda.BinaryReader.
       * @returns {Promise<barracuda.Model>}
       */
      async open(context) {
          const metadata = barracuda.Metadata.open();
          const binary = await context.read('binary');
          const parsed = new barracuda.NNModel(binary);
          return new barracuda.Model(metadata, parsed);
      }
  };
  /**
   * Top-level model: a format label carrying the file version plus a single
   * graph built from the parsed barracuda.NNModel.
   */
  barracuda.Model = class {
      constructor(metadata, model) {
          this.format = `Barracuda v${model.version.toString()}`;
          const graph = new barracuda.Graph(metadata, model);
          this.modules = [graph];
      }
  };
  /**
   * The single graph of a Barracuda model.
   *
   * 'Load' layers (type 255) without inputs are not turned into nodes; their
   * tensors are registered as constant values that other nodes reference by name.
   */
  barracuda.Graph = class {
      constructor(metadata, model) {
          this.name = '';
          this.inputs = [];
          this.outputs = [];
          this.nodes = [];
          // Name -> barracuda.Value table, also handed to every barracuda.Node.
          // `values.map(name, type, tensor)` returns the existing value for `name` or
          // creates it; passing a type or tensor for an existing name is a duplicate.
          const values = new Map();
          values.map = (name, type, tensor) => {
              if (values.has(name)) {
                  if (type || tensor) {
                      throw new barracuda.Error(`Duplicate value '${name}'.`);
                  }
                  return values.get(name);
              }
              const value = new barracuda.Value(name, tensor ? tensor.type : type, tensor);
              values.set(name, value);
              return value;
          };
          const layers = [];
          for (const layer of model.layers) {
              const isConstantLoad = layer.type === 255 && !(layer.inputs.length > 0);
              if (isConstantLoad) {
                  for (const tensor of layer.tensors) {
                      values.map(tensor.name, null, new barracuda.Tensor(tensor));
                  }
              } else {
                  layers.push(layer);
              }
          }
          // Graph inputs are always declared with 4-byte (float32) elements.
          for (const { name, shape } of model.inputs) {
              const type = new barracuda.TensorType(4, new barracuda.TensorShape(shape));
              this.inputs.push(new barracuda.Argument(name, [values.map(name, type)]));
          }
          for (const name of model.outputs) {
              this.outputs.push(new barracuda.Argument(name, [values.map(name)]));
          }
          for (const layer of layers) {
              this.nodes.push(new barracuda.Node(metadata, layer, null, values));
          }
      }
  };
  /**
   * Named slot of a node or graph: for inputs/outputs `value` is a list of
   * barracuda.Value objects, for attributes it is the raw attribute value and
   * `type` names its data type.
   */
  barracuda.Argument = class {
      constructor(name, value, type = null) {
          Object.assign(this, { name, value, type });
      }
  };
  /**
   * A named value flowing through the graph; `initializer` holds the constant
   * barracuda.Tensor when the value is backed by stored weights.
   */
  barracuda.Value = class {
      constructor(name, type = null, initializer = null) {
          Object.assign(this, { name, type, initializer });
      }
  };
  /**
   * Graph node built from one parsed Barracuda layer.
   *
   * Layer inputs, then constant tensors, are bound in order to the input slots
   * declared by the operator type; once the slots run out, positional index
   * names are used. A non-zero activation (or any activation on an
   * 'Activation' layer, type 50) is exposed as a one-element `chain` holding a
   * synthetic activation node.
   *
   * @param {object} metadata - registry whose `type(id)` resolves `layer.type`.
   * @param {object} layer - parsed layer record, or `{}` for a synthetic activation node.
   * @param {object|null} type - explicit operator type; falsy means look it up in `metadata`.
   * @param {Map} values - graph value table; its `map(name, type, tensor)` helper resolves
   *     or creates a value by name.
   * @throws {barracuda.Error} when `layer.activation` has no entry in barracuda.Activation.
   */
  barracuda.Node = class {
      constructor(metadata, layer, type, values) {
          this.name = layer.name || '';
          this.type = type ? type : metadata.type(layer.type);
          this.inputs = [];
          this.outputs = [];
          this.attributes = [];
          // Remaining declared input slots. Operator types without declared slots (e.g. ids
          // unknown to the registry) fall back to a single 'input' slot. The fallback must be
          // an object like the registered slots: a bare string made `.name` evaluate to
          // undefined and left the first argument unnamed.
          const inputs = this.type.inputs ? this.type.inputs.slice() : [{ name: 'input' }];
          const nextName = (index) => (inputs.length > 0 && inputs[0] ? inputs.shift().name : index.toString());
          if (this.type.inputs && this.type.inputs.length === 1 && this.type.inputs[0].name === 'inputs') {
              // Variadic operator: all layer inputs go into one 'inputs' argument.
              const argument = new barracuda.Argument('inputs', layer.inputs.map((input) => values.map(input)));
              this.inputs.push(argument);
          } else if (layer.inputs) {
              for (let i = 0; i < layer.inputs.length; i++) {
                  const argument = new barracuda.Argument(nextName(i), [values.map(layer.inputs[i])]);
                  this.inputs.push(argument);
              }
          }
          if (layer.tensors) {
              for (let i = 0; i < layer.tensors.length; i++) {
                  const tensor = layer.tensors[i];
                  const initializer = new barracuda.Tensor(tensor);
                  const name = nextName(i);
                  const argument = new barracuda.Argument(name, [values.map(tensor.name, initializer.type, initializer)]);
                  this.inputs.push(argument);
              }
          }
          // Layers that carry an inputs list (unlike synthetic activation nodes) expose one
          // output value keyed by the node name.
          if (layer.inputs !== undefined) {
              const argument = new barracuda.Argument('output', [values.map(this.name)]);
              this.outputs.push(argument);
          }
          if (layer.activation !== undefined && (layer.type === 50 || layer.activation !== 0)) {
              const activation = barracuda.Activation[layer.activation];
              if (!activation) {
                  throw new barracuda.Error(`Unsupported activation '${layer.activation}'.`);
              }
              const node = new barracuda.Node(metadata, {}, { name: activation, category: 'Activation' }, values);
              this.chain = [node];
          }
          // [attribute, type, default]; attributes equal to their default are omitted.
          // A function default is a predicate: 'pads' is omitted when all zeros or all -1.
          const attributes = [
              ['strides', 'int32[]', []],
              ['pads', 'int32[]', (value) => Array.isArray(value) && (value.every((v) => v === 0) || value.every((v) => v === -1))],
              ['pool_size', 'int32[]', []],
              ['alpha', 'float32', 1],
              ['beta', 'float32', 0],
              ['axis', 'int32', -1]
          ];
          const isDefault = (value, defaultValue) =>
              (Array.isArray(defaultValue) && Array.isArray(value) && value.length === defaultValue.length && value.every((v, i) => v === defaultValue[i])) ||
              (typeof defaultValue === 'function' && defaultValue(value)) ||
              defaultValue === value;
          for (const [name, attributeType, defaultValue] of attributes) {
              const value = layer[name];
              if (value === undefined || isDefault(value, defaultValue)) {
                  continue;
              }
              this.attributes.push(new barracuda.Argument(name, value, attributeType));
          }
      }
  };
  /**
   * Constant tensor (initializer) backed by the raw bytes that barracuda.NNModel
   * read for one layer tensor descriptor.
   */
  barracuda.Tensor = class {
      constructor(tensor) {
          const shape = new barracuda.TensorShape(tensor.shape);
          this.type = new barracuda.TensorType(tensor.itemsize, shape);
          this.values = tensor.data;
      }
  };
  /**
   * Element type plus shape. Only 4-byte items are accepted, and they are
   * reported as float32.
   * @throws {barracuda.Error} for any other item size.
   */
  barracuda.TensorType = class {
      constructor(itemsize, shape) {
          if (itemsize !== 4) {
              throw new barracuda.Error(`Unsupported data type size '${itemsize}'.`);
          }
          this.dataType = 'float32';
          this.shape = shape;
      }
      toString() {
          return `${this.dataType}${this.shape.toString()}`;
      }
  };
  /**
   * Tensor dimensions as read from the file. `toString()` renders them as
   * `[d0,d1,...]`; falsy dimensions (0, null, undefined) print as '?', and a
   * missing dimension list prints as ''.
   */
  barracuda.TensorShape = class {
      constructor(dimensions) {
          this.dimensions = dimensions;
      }
      toString() {
          if (!this.dimensions) {
              return '';
          }
          const parts = this.dimensions.map((dimension) => (dimension ? dimension.toString() : '?'));
          return `[${parts.join(',')}]`;
      }
  };
  /**
   * Parser for the Barracuda binary model layout, following
   * https://github.com/Unity-Technologies/barracuda-release/blob/release/1.3.2/Barracuda/Runtime/Core/Model.cs
   *
   * The header, inputs, outputs, memories and layer table are read strictly in
   * file order (every read advances the stream, so statement order matters).
   * Tensor payloads follow the layer table; each payload starts
   * `offset * itemsize` bytes after the end of that table.
   */
  barracuda.NNModel = class {
      constructor(reader) {
          const stream = new barracuda.BinaryReader(reader);
          // One tensor descriptor inside a layer record.
          const readTensor = () => {
              const name = stream.string();
              const shape = stream.shape();
              const offset = stream.int64().toNumber();
              const itemsize = stream.int32();
              const length = stream.int32();
              return { name, shape, offset, itemsize, length };
          };
          // One layer record; the unnamed int32 fields are read and discarded.
          const readLayer = () => {
              const name = stream.string();
              const type = stream.int32();
              const activation = stream.int32();
              stream.int32();
              stream.int32();
              const pads = stream.int32s();
              const strides = stream.int32s();
              const poolSize = stream.int32s();
              const axis = stream.int32();
              const alpha = stream.float32();
              const beta = stream.float32();
              stream.int32();
              const inputs = stream.strings();
              const tensorCount = stream.int32();
              const tensors = [];
              for (let j = 0; j < tensorCount; j++) {
                  tensors.push(readTensor());
              }
              return { name, type, activation, pads, strides, pool_size: poolSize, axis, alpha, beta, inputs, tensors };
          };
          this.version = stream.int32();
          stream.int32(); // second header word, not used
          const inputCount = stream.int32();
          this.inputs = new Array(inputCount);
          for (let index = 0; index < inputCount; index++) {
              const name = stream.string();
              const shape = stream.shape();
              this.inputs[index] = { name, shape };
          }
          this.outputs = stream.strings();
          const memoryCount = stream.int32();
          this.memories = new Array(memoryCount);
          for (let index = 0; index < memoryCount; index++) {
              const shape = stream.shape();
              const input = stream.string();
              const output = stream.string();
              this.memories[index] = { shape, in: input, out: output };
          }
          const layerCount = stream.int32();
          this.layers = new Array(layerCount);
          for (let index = 0; index < layerCount; index++) {
              this.layers[index] = readLayer();
          }
          const base = stream.position;
          for (const { tensors } of this.layers) {
              for (const tensor of tensors) {
                  stream.seek(base + (tensor.offset * tensor.itemsize));
                  tensor.data = stream.read(tensor.length * tensor.itemsize);
              }
          }
      }
  };
  // Numeric `layer.activation` id -> operator name of the synthetic activation
  // node that barracuda.Node chains after a layer. Ids absent here are rejected.
  barracuda.Activation = {
      // basic
      0: 'Linear', 1: 'Relu', 2: 'Softmax', 3: 'Tanh', 4: 'Sigmoid',
      5: 'Elu', 6: 'Relu6', 7: 'LeakyRelu', 8: 'Selu', 9: 'Swish',
      10: 'LogSoftmax', 11: 'Softplus', 12: 'Softsign', 13: 'PRelu',
      20: 'Hardmax', 21: 'HardSigmoid',
      // element-wise math
      100: 'Abs', 101: 'Neg', 102: 'Ceil', 103: 'Clip', 104: 'Floor', 105: 'Round',
      110: 'Reciprocal', 111: 'Sqrt', 113: 'Exp', 114: 'Log',
      // trigonometric
      200: 'Acos', 201: 'Acosh', 202: 'Asin', 203: 'Asinh', 204: 'Atan', 205: 'Atanh',
      206: 'Cos', 207: 'Cosh', 208: 'Sin', 209: 'Sinh', 210: 'Tan'
  };
  /**
   * Wrapper over the host binary reader. It forwards the primitive reads and
   * adds the count-prefixed composites used by the .nn format: int32 arrays,
   * byte strings and string lists, each preceded by an int32 count.
   */
  barracuda.BinaryReader = class {
      constructor(reader) {
          this._reader = reader;
      }
      get position() {
          return this._reader.position;
      }
      seek(position) {
          this._reader.seek(position);
      }
      skip(offset) {
          this._reader.skip(offset);
      }
      read(length) {
          return this._reader.read(length);
      }
      byte() {
          return this._reader.byte();
      }
      int32() {
          return this._reader.int32();
      }
      int64() {
          return this._reader.int64();
      }
      float32() {
          return this._reader.float32();
      }
      /** Reads an int32 count followed by that many int32 values. */
      int32s() {
          const count = this.int32();
          const values = new Array(count);
          for (let i = 0; i < count; i++) {
              values[i] = this.int32();
          }
          return values;
      }
      /**
       * Reads an int32 byte count followed by that many bytes. Each byte becomes
       * one UTF-16 code unit (Latin-1 style); no UTF-8 decoding is performed.
       */
      string() {
          const size = this.int32();
          const chars = [];
          for (let i = 0; i < size; i++) {
              chars.push(String.fromCharCode(this.byte()));
          }
          return chars.join('');
      }
      /** Reads an int32 count followed by that many strings. */
      strings() {
          const count = this.int32();
          return Array.from({ length: count }, () => this.string());
      }
      /** A shape is stored as a plain int32 array. */
      shape() {
          return this.int32s();
      }
  };
  /**
   * Registry of Barracuda layer types keyed by the numeric `layer.type` id.
   * Each entry is `{ name, category, inputs }`, where `inputs` lists the named
   * input slots that barracuda.Node binds layer inputs and tensors to.
   */
  barracuda.Metadata = class {

      /** Returns the shared registry, creating it on first use. */
      static open() {
          if (!barracuda.Metadata._metadata) {
              barracuda.Metadata._metadata = new barracuda.Metadata();
          }
          return barracuda.Metadata._metadata;
      }

      constructor() {
          this._types = new Map();
          const weighted = ['input', 'kernel', 'bias'];
          const variadic = ['inputs'];
          // [id, name, category, input slot names (optional)]
          const operators = [
              [0, 'Nop', ''],
              [1, 'Dense', 'Layer', weighted],
              [2, 'MatMul', '', weighted],
              [20, 'Conv2D', 'Layer', weighted],
              [21, 'DepthwiseConv2D', 'Layer', weighted],
              [22, 'Conv2DTrans', 'Layer', weighted],
              [23, 'Upsample2D', 'Data'],
              [25, 'MaxPool2D', 'Pool'],
              [26, 'AvgPool2D', 'Pool'],
              [27, 'GlobalMaxPool2D', 'Pool'],
              [28, 'GlobalAvgPool2D', 'Pool'],
              [29, 'Border2D', ''],
              [30, 'Conv3D', 'Layer'],
              [32, 'Conv3DTrans', 'Layer'],
              [33, 'Upsample3D', 'Data'],
              [35, 'MaxPool3D', 'Pool'],
              [36, 'AvgPool3D', 'Pool'],
              [37, 'GlobalMaxPool3D', 'Pool'],
              [38, 'GlobalAvgPool3D', 'Pool'],
              [39, 'Border3D', ''],
              [50, 'Activation', '', ['input']],
              [51, 'ScaleBias', 'Normalization', ['input', 'scale', 'bias']],
              [52, 'Normalization', 'Normalization'],
              [53, 'LRN', 'Normalization'],
              [60, 'Dropout', 'Dropout'],
              [64, 'RandomNormal', ''],
              [65, 'RandomUniform', ''],
              [66, 'Multinomial', ''],
              [67, 'OneHot', ''],
              [68, 'TopKIndices', ''],
              [69, 'TopKValues', ''],
              [100, 'Add', '', variadic],
              [101, 'Sub', '', variadic],
              [102, 'Mul', '', variadic],
              [103, 'RealDiv', '', variadic],
              [104, 'Pow', '', variadic],
              [110, 'Minimum', '', variadic],
              [111, 'Maximum', '', variadic],
              [112, 'Mean', '', variadic],
              [120, 'ReduceL1', '', variadic],
              [121, 'ReduceL2', '', variadic],
              [122, 'ReduceLogSum', '', variadic],
              [123, 'ReduceLogSumExp', '', variadic],
              [124, 'ReduceMax', '', variadic],
              [125, 'ReduceMean', '', variadic],
              [126, 'ReduceMin', '', variadic],
              [127, 'ReduceProd', '', variadic],
              [128, 'ReduceSum', '', variadic],
              [129, 'ReduceSumSquare', '', variadic],
              [140, 'Greater', ''],
              [141, 'GreaterEqual', ''],
              [142, 'Less', ''],
              [143, 'LessEqual', ''],
              [144, 'Equal', ''],
              [145, 'LogicalOr', ''],
              [146, 'LogicalAnd', ''],
              [147, 'LogicalNot', ''],
              [148, 'LogicalXor', ''],
              [160, 'Pad2DReflect', ''],
              [161, 'Pad2DSymmetric', ''],
              [162, 'Pad2DEdge', ''],
              [200, 'Flatten', 'Shape'],
              [201, 'Reshape', 'Shape'],
              [202, 'Transpose', ''],
              [203, 'Squeeze', ''],
              [204, 'Unsqueeze', ''],
              [205, 'Gather', ''],
              [206, 'DepthToSpace', ''],
              [207, 'SpaceToDepth', ''],
              [208, 'Expand', ''],
              [209, 'Resample2D', ''],
              [210, 'Concat', 'Tensor', variadic],
              [211, 'StridedSlice', 'Shape'],
              [212, 'Tile', ''],
              [213, 'Shape', ''],
              [214, 'NonMaxSuppression', ''],
              [215, 'LSTM', ''],
              [255, 'Load', '']
          ];
          for (const [id, name, category, inputs = []] of operators) {
              // Fresh slot objects per entry, so entries never share input arrays.
              this._types.set(id, { name, category, inputs: inputs.map((input) => ({ name: input })) });
          }
      }

      /**
       * Looks up a layer type. Unknown ids get (and cache) a placeholder entry
       * named after the id, with no category and no declared input slots.
       */
      type(name) {
          if (this._types.has(name)) {
              return this._types.get(name);
          }
          const placeholder = { name: name.toString() };
          this._types.set(name, placeholder);
          return placeholder;
      }
  };
  /**
   * Error raised for malformed or unsupported Barracuda content. The class is
   * intentionally anonymous; `name` carries the user-facing error title.
   */
  barracuda.Error = class extends Error {
      constructor(message) {
          super(message);
          this.name = 'Error loading Barracuda model.';
      }
  };
  // Public entry point of this module.
  export const { ModelFactory } = barracuda;