asr_client_imp.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. // Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "asr_client_imp.h"
  15. #include <unistd.h>
  16. #include <cmath>
  17. #include <cstring>
  18. #include <iomanip>
  19. #include <numeric>
  20. #include <sstream>
  21. #include "lat/kaldi-lattice.h"
  22. #include "lat/lattice-functions.h"
  23. #include "util/kaldi-table.h"
  // FAIL_IF_ERR(X, MSG): evaluates X (an expression yielding nic::Error). If
  // the result is not OK, prints "error: MSG: <error>" to stderr and
  // terminates the process with exit status 1.
  //
  // The expansion is wrapped in do { ... } while (0) so `FAIL_IF_ERR(...);`
  // is a single statement (safe in an unbraced if/else). The local has a
  // deliberately unusual name so that an `err` used in X or MSG still refers
  // to the caller's variable.
  #define FAIL_IF_ERR(X, MSG)                                                \
    do {                                                                     \
      nic::Error fail_if_err_status = (X);                                   \
      if (!fail_if_err_status.IsOk()) {                                      \
        std::cerr << "error: " << (MSG) << ": " << fail_if_err_status        \
                  << std::endl;                                              \
        exit(1);                                                             \
      }                                                                      \
    } while (0)
  32. void TritonASRClient::CreateClientContext() {
  33. clients_.emplace_back();
  34. TritonClient& client = clients_.back();
  35. FAIL_IF_ERR(nic::InferenceServerGrpcClient::Create(&client.triton_client,
  36. url_, /*verbose*/ false),
  37. "unable to create triton client");
  38. FAIL_IF_ERR(
  39. client.triton_client->StartStream(
  40. [&](nic::InferResult* result) {
  41. double end_timestamp = gettime_monotonic();
  42. std::unique_ptr<nic::InferResult> result_ptr(result);
  43. FAIL_IF_ERR(result_ptr->RequestStatus(),
  44. "inference request failed");
  45. std::string request_id;
  46. FAIL_IF_ERR(result_ptr->Id(&request_id),
  47. "unable to get request id for response");
  48. uint64_t corr_id =
  49. std::stoi(std::string(request_id, 0, request_id.find("_")));
  50. bool end_of_stream = (request_id.back() == '1');
  51. if (!end_of_stream) {
  52. if (print_partial_results_) {
  53. std::vector<std::string> text;
  54. FAIL_IF_ERR(result_ptr->StringData("TEXT", &text),
  55. "unable to get TEXT output");
  56. std::lock_guard<std::mutex> lk(stdout_m_);
  57. std::cout << "CORR_ID " << corr_id << "\t[partial]\t" << text[0]
  58. << '\n';
  59. }
  60. return;
  61. }
  62. double start_timestamp;
  63. {
  64. std::lock_guard<std::mutex> lk(start_timestamps_m_);
  65. auto it = start_timestamps_.find(corr_id);
  66. if (it != start_timestamps_.end()) {
  67. start_timestamp = it->second;
  68. start_timestamps_.erase(it);
  69. } else {
  70. std::cerr << "start_timestamp not found" << std::endl;
  71. exit(1);
  72. }
  73. }
  74. if (print_results_) {
  75. std::vector<std::string> text;
  76. FAIL_IF_ERR(result_ptr->StringData(ctm_ ? "CTM" : "TEXT", &text),
  77. "unable to get TEXT or CTM output");
  78. std::lock_guard<std::mutex> lk(stdout_m_);
  79. std::cout << "CORR_ID " << corr_id;
  80. std::cout << (ctm_ ? "\n" : "\t\t");
  81. std::cout << text[0] << std::endl;
  82. }
  83. std::vector<std::string> lattice_bytes;
  84. FAIL_IF_ERR(result_ptr->StringData("RAW_LATTICE", &lattice_bytes),
  85. "unable to get RAW_LATTICE output");
  86. {
  87. double elapsed = end_timestamp - start_timestamp;
  88. std::lock_guard<std::mutex> lk(results_m_);
  89. results_.insert(
  90. {corr_id, {std::move(lattice_bytes[0]), elapsed}});
  91. }
  92. n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
  93. },
  94. false),
  95. "unable to establish a streaming connection to server");
  96. }
  97. void TritonASRClient::SendChunk(uint64_t corr_id, bool start_of_sequence,
  98. bool end_of_sequence, float* chunk,
  99. int chunk_byte_size, const uint64_t index) {
  100. // Setting options
  101. nic::InferOptions options(model_name_);
  102. options.sequence_id_ = corr_id;
  103. options.sequence_start_ = start_of_sequence;
  104. options.sequence_end_ = end_of_sequence;
  105. options.request_id_ = std::to_string(corr_id) + "_" + std::to_string(index) +
  106. "_" + (start_of_sequence ? "1" : "0") + "_" +
  107. (end_of_sequence ? "1" : "0");
  108. // Initialize the inputs with the data.
  109. nic::InferInput* wave_data_ptr;
  110. std::vector<int64_t> wav_shape{1, samps_per_chunk_};
  111. FAIL_IF_ERR(
  112. nic::InferInput::Create(&wave_data_ptr, "WAV_DATA", wav_shape, "FP32"),
  113. "unable to create 'WAV_DATA'");
  114. std::shared_ptr<nic::InferInput> wave_data_in(wave_data_ptr);
  115. FAIL_IF_ERR(wave_data_in->Reset(), "unable to reset 'WAV_DATA'");
  116. uint8_t* wave_data = reinterpret_cast<uint8_t*>(chunk);
  117. if (chunk_byte_size < max_chunk_byte_size_) {
  118. std::memcpy(&chunk_buf_[0], chunk, chunk_byte_size);
  119. wave_data = &chunk_buf_[0];
  120. }
  121. FAIL_IF_ERR(wave_data_in->AppendRaw(wave_data, max_chunk_byte_size_),
  122. "unable to set data for 'WAV_DATA'");
  123. // Dim
  124. nic::InferInput* dim_ptr;
  125. std::vector<int64_t> shape{1, 1};
  126. FAIL_IF_ERR(nic::InferInput::Create(&dim_ptr, "WAV_DATA_DIM", shape, "INT32"),
  127. "unable to create 'WAV_DATA_DIM'");
  128. std::shared_ptr<nic::InferInput> dim_in(dim_ptr);
  129. FAIL_IF_ERR(dim_in->Reset(), "unable to reset WAVE_DATA_DIM");
  130. int nsamples = chunk_byte_size / sizeof(float);
  131. FAIL_IF_ERR(
  132. dim_in->AppendRaw(reinterpret_cast<uint8_t*>(&nsamples), sizeof(int32_t)),
  133. "unable to set data for WAVE_DATA_DIM");
  134. std::vector<nic::InferInput*> inputs = {wave_data_in.get(), dim_in.get()};
  135. std::vector<const nic::InferRequestedOutput*> outputs;
  136. std::shared_ptr<nic::InferRequestedOutput> raw_lattice, text;
  137. outputs.reserve(2);
  138. if (end_of_sequence) {
  139. nic::InferRequestedOutput* raw_lattice_ptr;
  140. FAIL_IF_ERR(
  141. nic::InferRequestedOutput::Create(&raw_lattice_ptr, "RAW_LATTICE"),
  142. "unable to get 'RAW_LATTICE'");
  143. raw_lattice.reset(raw_lattice_ptr);
  144. outputs.push_back(raw_lattice.get());
  145. // Request the TEXT results only when required for printing
  146. if (print_results_) {
  147. nic::InferRequestedOutput* text_ptr;
  148. FAIL_IF_ERR(
  149. nic::InferRequestedOutput::Create(&text_ptr, ctm_ ? "CTM" : "TEXT"),
  150. "unable to get 'TEXT' or 'CTM'");
  151. text.reset(text_ptr);
  152. outputs.push_back(text.get());
  153. }
  154. } else if (print_partial_results_) {
  155. nic::InferRequestedOutput* text_ptr;
  156. FAIL_IF_ERR(nic::InferRequestedOutput::Create(&text_ptr, "TEXT"),
  157. "unable to get 'TEXT'");
  158. text.reset(text_ptr);
  159. outputs.push_back(text.get());
  160. }
  161. total_audio_ += (static_cast<double>(nsamples) / samp_freq_);
  162. if (start_of_sequence) {
  163. n_in_flight_.fetch_add(1, std::memory_order_consume);
  164. }
  165. // Record the timestamp when the last chunk was made available.
  166. if (end_of_sequence) {
  167. std::lock_guard<std::mutex> lk(start_timestamps_m_);
  168. start_timestamps_[corr_id] = gettime_monotonic();
  169. }
  170. TritonClient* client = &clients_[corr_id % nclients_];
  171. // nic::InferenceServerGrpcClient& triton_client = *client->triton_client;
  172. FAIL_IF_ERR(client->triton_client->AsyncStreamInfer(options, inputs, outputs),
  173. "unable to run model");
  174. }
  175. void TritonASRClient::WaitForCallbacks() {
  176. while (n_in_flight_.load(std::memory_order_consume)) {
  177. usleep(1000);
  178. }
  179. }
  180. void TritonASRClient::PrintStats(bool print_latency_stats,
  181. bool print_throughput) {
  182. double now = gettime_monotonic();
  183. double diff = now - started_at_;
  184. double rtf = total_audio_ / diff;
  185. if (print_throughput)
  186. std::cout << "Throughput:\t" << rtf << " RTFX" << std::endl;
  187. std::vector<double> latencies;
  188. {
  189. std::lock_guard<std::mutex> lk(results_m_);
  190. latencies.reserve(results_.size());
  191. for (auto& result : results_) latencies.push_back(result.second.latency);
  192. }
  193. std::sort(latencies.begin(), latencies.end());
  194. double nresultsf = static_cast<double>(latencies.size());
  195. size_t per90i = static_cast<size_t>(std::floor(90. * nresultsf / 100.));
  196. size_t per95i = static_cast<size_t>(std::floor(95. * nresultsf / 100.));
  197. size_t per99i = static_cast<size_t>(std::floor(99. * nresultsf / 100.));
  198. double lat_90 = latencies[per90i];
  199. double lat_95 = latencies[per95i];
  200. double lat_99 = latencies[per99i];
  201. double avg = std::accumulate(latencies.begin(), latencies.end(), 0.0) /
  202. latencies.size();
  203. std::cout << std::setprecision(3);
  204. std::cout << "Latencies:\t90%\t\t95%\t\t99%\t\tAvg\n";
  205. if (print_latency_stats) {
  206. std::cout << "\t\t" << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99
  207. << "\t\t" << avg << std::endl;
  208. } else {
  209. std::cout << "\t\tN/A\t\tN/A\t\tN/A\t\tN/A" << std::endl;
  210. std::cout << "Latency statistics are printed only when the "
  211. "online option is set (-o)."
  212. << std::endl;
  213. }
  214. }
  215. TritonASRClient::TritonASRClient(const std::string& url,
  216. const std::string& model_name,
  217. const int nclients, bool print_results,
  218. bool print_partial_results, bool ctm,
  219. float samp_freq)
  220. : url_(url),
  221. model_name_(model_name),
  222. nclients_(nclients),
  223. print_results_(print_results),
  224. print_partial_results_(print_partial_results),
  225. ctm_(ctm),
  226. samp_freq_(samp_freq) {
  227. nclients_ = std::max(nclients_, 1);
  228. for (int i = 0; i < nclients_; ++i) CreateClientContext();
  229. inference::ModelMetadataResponse model_metadata;
  230. FAIL_IF_ERR(
  231. clients_[0].triton_client->ModelMetadata(&model_metadata, model_name),
  232. "unable to get model metadata");
  233. for (const auto& in_tensor : model_metadata.inputs()) {
  234. if (in_tensor.name().compare("WAV_DATA") == 0) {
  235. samps_per_chunk_ = in_tensor.shape()[1];
  236. }
  237. }
  238. max_chunk_byte_size_ = samps_per_chunk_ * sizeof(float);
  239. chunk_buf_.resize(max_chunk_byte_size_);
  240. shape_ = {max_chunk_byte_size_};
  241. n_in_flight_.store(0);
  242. started_at_ = gettime_monotonic();
  243. total_audio_ = 0;
  244. }
  245. void TritonASRClient::WriteLatticesToFile(
  246. const std::string& clat_wspecifier,
  247. const std::unordered_map<uint64_t, std::string>& corr_id_and_keys) {
  248. kaldi::CompactLatticeWriter clat_writer;
  249. clat_writer.Open(clat_wspecifier);
  250. std::unordered_map<std::string, size_t> key_count;
  251. std::lock_guard<std::mutex> lk(results_m_);
  252. for (auto& p : corr_id_and_keys) {
  253. uint64_t corr_id = p.first;
  254. std::string key = p.second;
  255. const auto iter = key_count[key]++;
  256. if (iter > 0) {
  257. key += std::to_string(iter);
  258. }
  259. auto it = results_.find(corr_id);
  260. if (it == results_.end()) {
  261. std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
  262. continue;
  263. }
  264. const std::string& raw_lattice = it->second.raw_lattice;
  265. // We could in theory write directly the binary hold in raw_lattice (it is
  266. // in the kaldi lattice format) However getting back to a CompactLattice
  267. // object allows us to us CompactLatticeWriter
  268. std::istringstream iss(raw_lattice);
  269. kaldi::CompactLattice* clat = NULL;
  270. kaldi::ReadCompactLattice(iss, true, &clat);
  271. clat_writer.Write(key, *clat);
  272. }
  273. clat_writer.Close();
  274. }