Procházet zdrojové kódy

[Kaldi] Adding Jupyter notebook

kkudrynski před 5 roky
rodič
revize
799660f2d5

+ 4 - 7
Kaldi/SpeechRecognition/Dockerfile

@@ -11,13 +11,10 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-FROM nvcr.io/nvidia/kaldi:19.12-online-beta-py3 as kb
+FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
+FROM nvcr.io/nvidia/tritonserver:20.03-py3
 ENV DEBIAN_FRONTEND=noninteractive
 ENV DEBIAN_FRONTEND=noninteractive
 
 
-ARG PYVER=3.6
-
-FROM nvcr.io/nvidia/tensorrtserver:19.12-py3
-
 # Kaldi dependencies
 # Kaldi dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
 RUN apt-get update && apt-get install -y --no-install-recommends \
         automake \
         automake \
@@ -27,8 +24,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
         gawk \
         gawk \
         libatlas3-base \
         libatlas3-base \
         libtool \
         libtool \
-        python$PYVER \
-        python$PYVER-dev \
+        python3.6 \
+        python3.6-dev \
         sox \
         sox \
         subversion \
         subversion \
         unzip \
         unzip \

+ 6 - 4
Kaldi/SpeechRecognition/Dockerfile.client

@@ -11,8 +11,8 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-FROM nvcr.io/nvidia/kaldi:19.12-online-beta-py3 as kb
-FROM nvcr.io/nvidia/tensorrtserver:19.12-py3-clientsdk
+FROM nvcr.io/nvidia/kaldi:20.03-py3 as kb
+FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
 
 
 # Kaldi dependencies
 # Kaldi dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
 RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
         gawk \
         gawk \
         libatlas3-base \
         libatlas3-base \
         libtool \
         libtool \
-        python$PYVER \
-        python$PYVER-dev \
+        python3.6 \
+        python3.6-dev \
         sox \
         sox \
         subversion \
         subversion \
         unzip \
         unzip \
@@ -36,6 +36,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 COPY --from=kb /opt/kaldi /opt/kaldi
 COPY --from=kb /opt/kaldi /opt/kaldi
 ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
 ENV LD_LIBRARY_PATH /opt/kaldi/src/lib/:$LD_LIBRARY_PATH
 
 
+COPY scripts /workspace/scripts
+
 COPY kaldi-asr-client /workspace/src/clients/c++/kaldi-asr-client
 COPY kaldi-asr-client /workspace/src/clients/c++/kaldi-asr-client
 RUN echo "add_subdirectory(kaldi-asr-client)" >> "/workspace/src/clients/c++/CMakeLists.txt"
 RUN echo "add_subdirectory(kaldi-asr-client)" >> "/workspace/src/clients/c++/CMakeLists.txt"
 RUN cd /workspace/build/ && make -j16 trtis-clients
 RUN cd /workspace/build/ && make -j16 trtis-clients

+ 31 - 0
Kaldi/SpeechRecognition/Dockerfile.notebook

@@ -0,0 +1,31 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM nvcr.io/nvidia/tritonserver:20.03-py3-clientsdk
+
+# Kaldi dependencies
+RUN apt-get update && apt-get install -y jupyter \
+                   python3-pyaudio \
+                   python-pyaudio \
+                   libasound-dev \
+                   portaudio19-dev \
+                   libportaudio2 \
+                   libportaudiocpp0 \
+                   libsndfile1 \
+                   alsa-base \
+                   alsa-utils \
+                   vim
+
+RUN python3 -m pip uninstall -y pip
+RUN apt install python3-pip --reinstall
+RUN pip3 install matplotlib soundfile librosa sounddevice

+ 22 - 17
Kaldi/SpeechRecognition/README.md

@@ -1,6 +1,6 @@
-# Kaldi ASR Integration With TensorRT Inference Server
+# Kaldi ASR Integration With Triton
 
 
-This repository provides a Kaldi ASR custom backend for the NVIDIA TensorRT Inference Server (TRTIS). It can be used to demonstrate high-performance online inference on Kaldi ASR models. This includes handling the gRPC communication between the TensorRT Inference Server and clients, and the dynamic batching of inference requests. This repository is tested and maintained by NVIDIA.
+This repository provides a Kaldi ASR custom backend for the NVIDIA Triton (former TensorRT Inference Server). It can be used to demonstrate high-performance online inference on Kaldi ASR models. This includes handling the gRPC communication between the Triton and clients, and the dynamic batching of inference requests. This repository is tested and maintained by NVIDIA.
 
 
 ## Table Of Contents
 ## Table Of Contents
 
 
@@ -33,9 +33,9 @@ This repository provides a Kaldi ASR custom backend for the NVIDIA TensorRT Infe
 
 
 This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.
 This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.
 
 
-This repository contains a TensorRT Inference Server custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. This TensorRT Inference Server integration provides ease-of-use to Kaldi ASR inference: gRPC streaming server, dynamic sequence batching, and multi-instances support. A client connects to the gRPC server, streams audio by sending chunks to the server, and gets back the inferred text as an answer (see [Input/Output](#input-output)). More information about the TensorRT Inference Server can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).  
+This repository contains a Triton custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. This Triton integration provides ease-of-use to Kaldi ASR inference: gRPC streaming server, dynamic sequence batching, and multi-instances support. A client connects to the gRPC server, streams audio by sending chunks to the server, and gets back the inferred text as an answer (see [Input/Output](#input-output)). More information about the Triton can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).  
 
 
-This TensorRT Inference Server integration is meant to be used with the LibriSpeech model for demonstration purposes. We include a pre-trained version of this model to allow you to easily test this work (see [Quick Start Guide](#quick-start-guide)). Both the TensorRT Inference Server integration and the underlying Kaldi ASR online GPU pipeline are a work in progress and will support more functionalities in the future. This includes online iVectors not currently supported in the Kaldi ASR GPU online pipeline and being replaced by a zero vector (see [Known issues](#known-issues)). Support for a custom Kaldi model is experimental (see [Using a custom Kaldi model](#using-custom-kaldi-model)).
+This Triton integration is meant to be used with the LibriSpeech model for demonstration purposes. We include a pre-trained version of this model to allow you to easily test this work (see [Quick Start Guide](#quick-start-guide)). Both the Triton integration and the underlying Kaldi ASR online GPU pipeline are a work in progress and will support more functionalities in the future. Support for a custom Kaldi model is experimental (see [Using a custom Kaldi model](#using-custom-kaldi-model)).
 
 
 ### Reference model
 ### Reference model
 
 
@@ -60,7 +60,7 @@ Details about parameters can be found in the [Parameters](#parameters) section.
 
 
 ### Requirements 
 ### Requirements 
 
 
-This repository contains Dockerfiles which extends the Kaldi and TensorRT Inference Server NVIDIA GPU Cloud (NGC) containers and encapsulates some dependencies. Aside from these dependencies, ensure you have [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) installed.
+This repository contains Dockerfiles which extends the Kaldi and Triton NVIDIA GPU Cloud (NGC) containers and encapsulates some dependencies. Aside from these dependencies, ensure you have [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker) installed.
 
 
 
 
 For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
 For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
@@ -108,7 +108,7 @@ The following command will stream 1000 parallel streams to the server. The `-p`
 
 
 ### Parameters
 ### Parameters
 
 
-The configuration is done through the `config.pbtxt` file available in `model-repo/` directory. It allows you to specify the following:
+The configuration is done through the `config.pbtxt` file available in the `model-repo/kaldi_online/` directory. It allows you to specify the following:
 
 
 ####  Model path
 ####  Model path
 
 
@@ -141,7 +141,7 @@ The inference engine configuration parameters configure the inference engine. Th
 
 
 ### Inference process
 ### Inference process
 
 
-Inference is done through simulating concurrent users. Each user is attributed to one utterance from the LibriSpeech dataset. It streams that utterance by cutting it into chunks and gets the final `TEXT` output once the final chunk has been sent. A parameter sets the number of active users being simulated in parallel.  
+Inference is done through simulating concurrent users. Each user is attributed to one utterance from the LibriSpeech dataset. It streams that utterance by cutting it into chunks and gets the final `TEXT` output once the final chunk has been sent. The `-c` parameter sets the number of active users being simulated in parallel.  
 
 
 ### Client command-line parameters
 ### Client command-line parameters
 
 
@@ -187,7 +187,8 @@ Even if only the best path is used, we are still generating a full lattice for b
 
 
 Support for Kaldi ASR models that are different from the provided LibriSpeech model is experimental. However, it is possible to modify the [Model Path](#model-path) section of the config file `model-repo/kaldi_online/config.pbtxt` to set up your own model. 
 Support for Kaldi ASR models that are different from the provided LibriSpeech model is experimental. However, it is possible to modify the [Model Path](#model-path) section of the config file `model-repo/kaldi_online/config.pbtxt` to set up your own model. 
 
 
-The models and Kaldi allocators are currently not shared between instances. This means that if your model is large, you may end up with not enough memory on the GPU to store two different instances. If that's the case, you can set `count` to `1` in the `instance_group` section of the config file.
+The models and Kaldi allocators are currently not shared between instances. This means that if your model is large, you may end up with not enough memory on the GPU to store two different instances. If that's the case, 
+you can set `count` to `1` in the [`instance_group` section](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/model_configuration.html#instance-groups) of the config file.
 
 
 ## Performance
 ## Performance
 
 
@@ -218,16 +219,17 @@ Our results were obtained by:
 1. Building and starting the server as described in [Quick Start Guide](#quick-start-guide).
 1. Building and starting the server as described in [Quick Start Guide](#quick-start-guide).
 2. Running  `scripts/run_inference_all_v100.sh` and  `scripts/run_inference_all_t4.sh`
 2. Running  `scripts/run_inference_all_v100.sh` and  `scripts/run_inference_all_t4.sh`
 
 
+
 | GPU | Realtime I/O | Number of parallel audio channels | Throughput (RTFX) | Latency | | | |
 | GPU | Realtime I/O | Number of parallel audio channels | Throughput (RTFX) | Latency | | | |
 | ------ | ------ | ------ | ------ | ------ | ------ | ------ |------ |
 | ------ | ------ | ------ | ------ | ------ | ------ | ------ |------ |
 | | | | | 90% | 95% | 99% | Avg |
 | | | | | 90% | 95% | 99% | Avg |
-| V100 | No | 2000 | 1769.8 | N/A | N/A | N/A | N/A |
-| V100 | Yes | 1500 |  1220 | 0.424 | 0.473 | 0.758 | 0.345 |
-| V100 | Yes | 1000 |  867.4 | 0.358 | 0.405 | 0.707 | 0.276 |
-| V100 | Yes | 800 |  647.8 | 0.304 | 0.325 | 0.517 | 0.238 |
-| T4 | No | 1000 | 906.7 | N/A | N/A | N/A| N/A |
-| T4 | Yes | 700 | 629.6 | 0.629 | 0.782 | 1.01 | 0.463 |
-| T4 | Yes | 400 | 373.7 | 0.417 | 0.441 | 0.690 | 0.349 |
+| V100 | No | 2000 | 1506.5 | N/A | N/A | N/A | N/A |
+| V100 | Yes | 1500 |  1243.2 | 0.582 | 0.699 | 1.04 | 0.400 |
+| V100 | Yes | 1000 |  884.1 | 0.379 | 0.393 | 0.788 | 0.333 |
+| V100 | Yes | 800 |  660.2 | 0.334 | 0.340 | 0.438 | 0.288 |
+| T4 | No | 1000 | 675.2 | N/A | N/A | N/A| N/A |
+| T4 | Yes | 700 | 629.2 | 0.945 | 1.08 | 1.27 | 0.645 |
+| T4 | Yes | 400 | 373.7 | 0.579 | 0.624 | 0.758 | 0.452 |
 
 
 ## Release notes
 ## Release notes
 
 
@@ -236,6 +238,9 @@ Our results were obtained by:
 January 2020
 January 2020
 * Initial release
 * Initial release
 
 
-### Known issues
+April 2020
+* Printing WER accuracy in Triton client
+* Using the latest Kaldi GPU ASR pipeline, extended support for features (ivectors, fbanks)
 
 
-Only mfcc features are supported at this time. The reference model used in the benchmark scripts requires both mfcc and iVector features to deliver the best accuracy. Support for iVector features will be added in a future release.
+### Known issues
+* No multi-gpu support for the Triton integration

+ 5 - 0
Kaldi/SpeechRecognition/kaldi-asr-client/CMakeLists.txt

@@ -32,6 +32,7 @@ target_include_directories(
   /opt/kaldi/src/
   /opt/kaldi/src/
 )
 )
 
 
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") # openfst yields many warnings
 target_include_directories(
 target_include_directories(
   kaldi_asr_parallel_client 
   kaldi_asr_parallel_client 
   PRIVATE
   PRIVATE
@@ -58,6 +59,10 @@ target_link_libraries(
   PRIVATE /opt/kaldi/src/lib/libkaldi-base.so
   PRIVATE /opt/kaldi/src/lib/libkaldi-base.so
 )
 )
 
 
+target_link_libraries(
+  kaldi_asr_parallel_client
+  PRIVATE /opt/kaldi/src/lat/libkaldi-lat.so
+)
 
 
 install(
 install(
   TARGETS kaldi_asr_parallel_client
   TARGETS kaldi_asr_parallel_client

+ 67 - 19
Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.cc

@@ -18,6 +18,11 @@
 #include <cstring>
 #include <cstring>
 #include <iomanip>
 #include <iomanip>
 #include <numeric>
 #include <numeric>
+#include <sstream>
+
+#include "lat/kaldi-lattice.h"
+#include "lat/lattice-functions.h"
+#include "util/kaldi-table.h"
 
 
 #define FAIL_IF_ERR(X, MSG)                                        \
 #define FAIL_IF_ERR(X, MSG)                                        \
   {                                                                \
   {                                                                \
@@ -31,11 +36,12 @@
 void TRTISASRClient::CreateClientContext() {
 void TRTISASRClient::CreateClientContext() {
   contextes_.emplace_back();
   contextes_.emplace_back();
   ClientContext& client = contextes_.back();
   ClientContext& client = contextes_.back();
-  FAIL_IF_ERR(nic::InferGrpcStreamContext::Create(
-                  &client.trtis_context, /*corr_id*/ -1, url_, model_name_,
-                  /*model_version*/ -1,
-                  /*verbose*/ false),
-              "unable to create context");
+  FAIL_IF_ERR(
+      nic::InferGrpcStreamContext::Create(&client.trtis_context,
+                                          /*corr_id*/ -1, url_, model_name_,
+                                          /*model_version*/ -1,
+                                          /*verbose*/ false),
+      "unable to create context");
 }
 }
 
 
 void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
 void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
@@ -59,6 +65,8 @@ void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
     options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_END,
     options->SetFlag(ni::InferRequestHeader::FLAG_SEQUENCE_END,
                      end_of_sequence);
                      end_of_sequence);
     for (const auto& output : context.Outputs()) {
     for (const auto& output : context.Outputs()) {
+      if (output->Name() == "TEXT" && !print_results_)
+        continue;  // no need for text output if not printing
       options->AddRawResult(output);
       options->AddRawResult(output);
     }
     }
   }
   }
@@ -89,27 +97,33 @@ void TRTISASRClient::SendChunk(ni::CorrelationID corr_id,
   total_audio_ += (static_cast<double>(nsamples) / 16000.);  // TODO freq
   total_audio_ += (static_cast<double>(nsamples) / 16000.);  // TODO freq
   double start = gettime_monotonic();
   double start = gettime_monotonic();
   FAIL_IF_ERR(context.AsyncRun([corr_id, end_of_sequence, start, this](
   FAIL_IF_ERR(context.AsyncRun([corr_id, end_of_sequence, start, this](
-                  nic::InferContext* ctx,
-                  const std::shared_ptr<nic::InferContext::Request>& request) {
+                                   nic::InferContext* ctx,
+                                   const std::shared_ptr<
+                                       nic::InferContext::Request>& request) {
     if (end_of_sequence) {
     if (end_of_sequence) {
       double elapsed = gettime_monotonic() - start;
       double elapsed = gettime_monotonic() - start;
-      std::string out;
       std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
       std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
       ctx->GetAsyncRunResults(request, &results);
       ctx->GetAsyncRunResults(request, &results);
 
 
-      if (results.size() != 1) {
-        std::cerr << "Warning: Could not read output for corr_id " << corr_id
-                  << std::endl;
+      if (results.empty()) {
+        std::cerr << "Warning: Could not read "
+                     "output for corr_id "
+                  << corr_id << std::endl;
       } else {
       } else {
-        FAIL_IF_ERR(results["TEXT"]->GetRawAtCursor(0, &out),
-                    "unable to get TEXT output");
         if (print_results_) {
         if (print_results_) {
+	  std::string text;
+	  FAIL_IF_ERR(results["TEXT"]->GetRawAtCursor(0, &text),
+			  "unable to get TEXT output");
           std::lock_guard<std::mutex> lk(stdout_m_);
           std::lock_guard<std::mutex> lk(stdout_m_);
-          std::cout << "CORR_ID " << corr_id << "\t\t" << out << std::endl;
+          std::cout << "CORR_ID " << corr_id << "\t\t" << text << std::endl;
         }
         }
+
+        std::string lattice_bytes;
+        FAIL_IF_ERR(results["RAW_LATTICE"]->GetRawAtCursor(0, &lattice_bytes),
+                    "unable to get RAW_LATTICE output");
         {
         {
           std::lock_guard<std::mutex> lk(results_m_);
           std::lock_guard<std::mutex> lk(results_m_);
-          results_.insert({corr_id, {std::move(out), elapsed}});
+          results_.insert({corr_id, {std::move(lattice_bytes), elapsed}});
         }
         }
       }
       }
       n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
       n_in_flight_.fetch_sub(1, std::memory_order_relaxed);
@@ -125,7 +139,7 @@ void TRTISASRClient::WaitForCallbacks() {
   }
   }
 }
 }
 
 
-void TRTISASRClient::PrintStats() {
+void TRTISASRClient::PrintStats(bool print_latency_stats) {
   double now = gettime_monotonic();
   double now = gettime_monotonic();
   double diff = now - started_at_;
   double diff = now - started_at_;
   double rtf = total_audio_ / diff;
   double rtf = total_audio_ / diff;
@@ -150,9 +164,16 @@ void TRTISASRClient::PrintStats() {
                latencies.size();
                latencies.size();
 
 
   std::cout << std::setprecision(3);
   std::cout << std::setprecision(3);
-  std::cout << "Latencies:\t90\t\t95\t\t99\t\tAvg\n";
-  std::cout << "\t\t" << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99
-            << "\t\t" << avg << std::endl;
+  std::cout << "Latencies:\t90%\t\t95%\t\t99%\t\tAvg\n";
+  if (print_latency_stats) {
+    std::cout << "\t\t" << lat_90 << "\t\t" << lat_95 << "\t\t" << lat_99
+              << "\t\t" << avg << std::endl;
+  } else {
+    std::cout << "\t\tN/A\t\tN/A\t\tN/A\t\tN/A" << std::endl;
+    std::cout << "Latency statistics are printed only when the "
+                 "online option is set (-o)."
+              << std::endl;
+  }
 }
 }
 
 
 TRTISASRClient::TRTISASRClient(const std::string& url,
 TRTISASRClient::TRTISASRClient(const std::string& url,
@@ -175,3 +196,30 @@ TRTISASRClient::TRTISASRClient(const std::string& url,
   started_at_ = gettime_monotonic();
   started_at_ = gettime_monotonic();
   total_audio_ = 0;
   total_audio_ = 0;
 }
 }
+
+void TRTISASRClient::WriteLatticesToFile(
+    const std::string& clat_wspecifier,
+    const std::unordered_map<ni::CorrelationID, std::string>&
+        corr_id_and_keys) {
+  kaldi::CompactLatticeWriter clat_writer;
+  clat_writer.Open(clat_wspecifier);
+  std::lock_guard<std::mutex> lk(results_m_);
+  for (auto& p : corr_id_and_keys) {
+    ni::CorrelationID corr_id = p.first;
+    const std::string& key = p.second;
+    auto it = results_.find(corr_id);
+    if(it == results_.end()) {
+	    std::cerr << "Cannot find lattice for corr_id " << corr_id << std::endl;
+	    continue;
+    }
+    const std::string& raw_lattice = it->second.raw_lattice;
+    // We could in theory write directly the binary hold in raw_lattice (it is
+    // in the kaldi lattice format) However getting back to a CompactLattice
+    // object allows us to us CompactLatticeWriter
+    std::istringstream iss(raw_lattice);
+    kaldi::CompactLattice* clat = NULL;
+    kaldi::ReadCompactLattice(iss, true, &clat);
+    clat_writer.Write(key, *clat);
+  }
+  clat_writer.Close();
+}

+ 4 - 2
Kaldi/SpeechRecognition/kaldi-asr-client/asr_client_imp.h

@@ -15,6 +15,7 @@
 #include <queue>
 #include <queue>
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
+#include <unordered_map>
 
 
 #include "request_grpc.h"
 #include "request_grpc.h"
 
 
@@ -52,7 +53,7 @@ class TRTISASRClient {
   std::mutex stdout_m_;
   std::mutex stdout_m_;
 
 
   struct Result {
   struct Result {
-    std::string text;
+    std::string raw_lattice;
     double latency;
     double latency;
   };
   };
 
 
@@ -64,7 +65,8 @@ class TRTISASRClient {
   void SendChunk(uint64_t corr_id, bool start_of_sequence, bool end_of_sequence,
   void SendChunk(uint64_t corr_id, bool start_of_sequence, bool end_of_sequence,
                  float* chunk, int chunk_byte_size);
                  float* chunk, int chunk_byte_size);
   void WaitForCallbacks();
   void WaitForCallbacks();
-  void PrintStats();
+  void PrintStats(bool print_latency_stats);
+  void WriteLatticesToFile(const std::string &clat_wspecifier, const std::unordered_map<ni::CorrelationID, std::string> &corr_id_and_keys);
 
 
   TRTISASRClient(const std::string& url, const std::string& model_name,
   TRTISASRClient(const std::string& url, const std::string& model_name,
                  const int ncontextes, bool print_results);
                  const int ncontextes, bool print_results);

+ 56 - 45
Kaldi/SpeechRecognition/kaldi-asr-client/kaldi_asr_parallel_client.cc

@@ -114,7 +114,7 @@ int main(int argc, char** argv) {
   std::cout << "Number of iterations\t\t: " << niterations << std::endl;
   std::cout << "Number of iterations\t\t: " << niterations << std::endl;
   std::cout << "Number of parallel channels\t: " << nchannels << std::endl;
   std::cout << "Number of parallel channels\t: " << nchannels << std::endl;
   std::cout << "Server URL\t\t\t: " << url << std::endl;
   std::cout << "Server URL\t\t\t: " << url << std::endl;
-  std::cout << "Print results\t\t\t: " << (print_results ? "Yes" : "No")
+  std::cout << "Print text outputs\t\t: " << (print_results ? "Yes" : "No")
             << std::endl;
             << std::endl;
   std::cout << "Online - Realtime I/O\t\t: " << (online ? "Yes" : "No")
   std::cout << "Online - Realtime I/O\t\t: " << (online ? "Yes" : "No")
             << std::endl;
             << std::endl;
@@ -124,13 +124,11 @@ int main(int argc, char** argv) {
   // need to read wav files
   // need to read wav files
   SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
   SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
 
 
-  std::atomic<uint64_t> correlation_id;
-  correlation_id.store(1);  // 0 = no correlation
-
   double total_audio = 0;
   double total_audio = 0;
   // pre-loading data
   // pre-loading data
   // we don't want to measure I/O
   // we don't want to measure I/O
   std::vector<std::shared_ptr<WaveData>> all_wav;
   std::vector<std::shared_ptr<WaveData>> all_wav;
+  std::vector<std::string> all_wav_keys;
   {
   {
     std::cout << "Loading eval dataset..." << std::flush;
     std::cout << "Loading eval dataset..." << std::flush;
     for (; !wav_reader.Done(); wav_reader.Next()) {
     for (; !wav_reader.Done(); wav_reader.Next()) {
@@ -138,6 +136,7 @@ int main(int argc, char** argv) {
       std::shared_ptr<WaveData> wave_data = std::make_shared<WaveData>();
       std::shared_ptr<WaveData> wave_data = std::make_shared<WaveData>();
       wave_data->Swap(&wav_reader.Value());
       wave_data->Swap(&wav_reader.Value());
       all_wav.push_back(wave_data);
       all_wav.push_back(wave_data);
+      all_wav_keys.push_back(utt);
       total_audio += wave_data->Duration();
       total_audio += wave_data->Duration();
     }
     }
     std::cout << "done" << std::endl;
     std::cout << "done" << std::endl;
@@ -164,55 +163,67 @@ int main(int argc, char** argv) {
   next_tasks.reserve(nchannels);
   next_tasks.reserve(nchannels);
   size_t all_wav_i = 0;
   size_t all_wav_i = 0;
   size_t all_wav_max = all_wav.size() * niterations;
   size_t all_wav_max = all_wav.size() * niterations;
+
   while (true) {
   while (true) {
-      while (curr_tasks.size() < nchannels && all_wav_i < all_wav_max) {
-        // Creating new tasks
-        uint64_t corr_id = correlation_id.fetch_add(1);
-        std::unique_ptr<Stream> ptr(new Stream(all_wav[all_wav_i%(all_wav.size())], corr_id));
-        curr_tasks.emplace_back(std::move(ptr));
-        ++all_wav_i;
-      }
-      // If still empty, done
-      if (curr_tasks.empty()) break;
-
-      for (size_t itask = 0; itask < curr_tasks.size(); ++itask) {
-        Stream& task = *(curr_tasks[itask]);
-
-        SubVector<BaseFloat> data(task.wav->Data(), 0);
-        int32 samp_offset = task.offset;
-        int32 nsamp = data.Dim();
-        int32 samp_remaining = nsamp - samp_offset;
-        int32 num_samp =
-            chunk_length < samp_remaining ? chunk_length : samp_remaining;
-        bool is_last_chunk = (chunk_length >= samp_remaining);
-        SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
-        bool is_first_chunk = (samp_offset == 0);
-        if (online) {
-          double now = gettime_monotonic();
-          double wait_for = task.send_next_chunk_at - now;
-          if (wait_for > 0) usleep(wait_for * 1e6);
-        }
-        asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
-                             wave_part.Data(), wave_part.SizeInBytes());
-        task.send_next_chunk_at += chunk_seconds;
-        if (verbose)
-          std::cout << "Sending correlation_id=" << task.corr_id
-                    << " chunk offset=" << num_samp << std::endl;
-
-        task.offset += num_samp;
-        if (!is_last_chunk) next_tasks.push_back(std::move(curr_tasks[itask]));
+    while (curr_tasks.size() < nchannels && all_wav_i < all_wav_max) {
+      // Creating new tasks
+      uint64_t corr_id = static_cast<uint64_t>(all_wav_i) + 1;
+
+      std::unique_ptr<Stream> ptr(
+          new Stream(all_wav[all_wav_i % (all_wav.size())], corr_id));
+      curr_tasks.emplace_back(std::move(ptr));
+      ++all_wav_i;
+    }
+    // If still empty, done
+    if (curr_tasks.empty()) break;
+
+    for (size_t itask = 0; itask < curr_tasks.size(); ++itask) {
+      Stream& task = *(curr_tasks[itask]);
+
+      SubVector<BaseFloat> data(task.wav->Data(), 0);
+      int32 samp_offset = task.offset;
+      int32 nsamp = data.Dim();
+      int32 samp_remaining = nsamp - samp_offset;
+      int32 num_samp =
+          chunk_length < samp_remaining ? chunk_length : samp_remaining;
+      bool is_last_chunk = (chunk_length >= samp_remaining);
+      SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
+      bool is_first_chunk = (samp_offset == 0);
+      if (online) {
+        double now = gettime_monotonic();
+        double wait_for = task.send_next_chunk_at - now;
+        if (wait_for > 0) usleep(wait_for * 1e6);
       }
       }
+      asr_client.SendChunk(task.corr_id, is_first_chunk, is_last_chunk,
+                           wave_part.Data(), wave_part.SizeInBytes());
+      task.send_next_chunk_at += chunk_seconds;
+      if (verbose)
+        std::cout << "Sending correlation_id=" << task.corr_id
+                  << " chunk offset=" << num_samp << std::endl;
+
+      task.offset += num_samp;
+      if (!is_last_chunk) next_tasks.push_back(std::move(curr_tasks[itask]));
+    }
 
 
-      curr_tasks.swap(next_tasks);
-      next_tasks.clear();
-      // Showing activity if necessary
-      if (!print_results && !verbose) std::cout << "." << std::flush;
+    curr_tasks.swap(next_tasks);
+    next_tasks.clear();
+    // Showing activity if necessary
+    if (!print_results && !verbose) std::cout << "." << std::flush;
   }
   }
   std::cout << "done" << std::endl;
   std::cout << "done" << std::endl;
   std::cout << "Waiting for all results..." << std::flush;
   std::cout << "Waiting for all results..." << std::flush;
   asr_client.WaitForCallbacks();
   asr_client.WaitForCallbacks();
   std::cout << "done" << std::endl;
   std::cout << "done" << std::endl;
-  asr_client.PrintStats();
+
+  asr_client.PrintStats(online);
+
+  std::unordered_map<ni::CorrelationID, std::string> corr_id_and_keys;
+  for (size_t all_wav_i = 0; all_wav_i < all_wav.size(); ++all_wav_i) {
+    ni::CorrelationID corr_id = static_cast<ni::CorrelationID>(all_wav_i) + 1;
+    corr_id_and_keys.insert({corr_id, all_wav_keys[all_wav_i]});
+  }
+  asr_client.WriteLatticesToFile("ark:|gzip -c > /data/results/lat.cuda-asr.gz",
+                                 corr_id_and_keys);
 
 
   return 0;
   return 0;
 }
 }

+ 7 - 2
Kaldi/SpeechRecognition/model-repo/kaldi_online/config.pbtxt

@@ -46,7 +46,7 @@ string_value:"40"
 {
 {
 key: "max_execution_batch_size"
 key: "max_execution_batch_size"
 value: { 
 value: { 
-string_value:"512"
+string_value:"400"
 }
 }
 }]
 }]
 parameters: {
 parameters: {
@@ -115,7 +115,7 @@ max_sequence_idle_microseconds:5000000
   ]
   ]
 oldest {
 oldest {
 max_candidate_sequences:2200
 max_candidate_sequences:2200
-preferred_batch_size:[256,512]
+preferred_batch_size:[400]
 max_queue_delay_microseconds:1000
 max_queue_delay_microseconds:1000
 }
 }
 },
 },
@@ -133,6 +133,11 @@ input [
   }
   }
 ]
 ]
 output [
 output [
+  {
+    name: "RAW_LATTICE"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  },
   {
   {
     name: "TEXT"
     name: "TEXT"
     data_type: TYPE_STRING
     data_type: TYPE_STRING

Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 695 - 0
Kaldi/SpeechRecognition/notebooks/Kaldi_TRTIS_inference_offline_demo.ipynb


+ 975 - 0
Kaldi/SpeechRecognition/notebooks/Kaldi_TRTIS_inference_online_demo.ipynb

@@ -0,0 +1,975 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "Gwt7z7qdmTbW"
+   },
+   "outputs": [],
+   "source": [
+    "# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
+    "#\n",
+    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "# you may not use this file except in compliance with the License.\n",
+    "# You may obtain a copy of the License at\n",
+    "#\n",
+    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
+    "#\n",
+    "# Unless required by applicable law or agreed to in writing, software\n",
+    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "# See the License for the specific language governing permissions and\n",
+    "# limitations under the License.\n",
+    "# =============================================================================="
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "i4NKCp2VmTbn"
+   },
+   "source": [
+    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
+    "\n",
+    "# Kaldi TRTIS Inference Online Demo"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "fW0OKDzvmTbt"
+   },
+   "source": [
+    "## Overview\n",
+    "\n",
+    "\n",
+    "This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.\n",
+    "\n",
+    "This repository contains a TensorRT Inference Server custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. This TensorRT Inference Server integration provides ease-of-use to Kaldi ASR inference: gRPC streaming server, dynamic sequence batching, and multi-instances support. A client connects to the gRPC server, streams audio by sending chunks to the server, and gets back the inferred text as an answer. More information about the TensorRT Inference Server can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).  \n",
+    "\n",
+    "\n",
+    "\n",
+    "### Learning objectives\n",
+    "\n",
+    "This notebook demonstrates the steps for carrying out inferencing with the Kaldi TRTIS backend server using a Python gRPC client in an online context, that is, we will stream live audio from a microphone to the inference server and receive the results back.\n",
+    "\n",
+    "## Content\n",
+    "1. [Pre-requisite](#1)\n",
+    "1. [Setup](#2)\n",
+    "1. [Audio helper classes](#3)\n",
+    "1. [Inference](#4)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "aDFrE4eqmTbv"
+   },
+   "source": [
+    "<a id=\"1\"></a>\n",
+    "## 1. Pre-requisite\n",
+    "\n",
+    "\n",
+    "### 1.1 Docker containers\n",
+    "Follow the steps in [README](README.md) to build Kaldi server and client containers.\n",
+    "\n",
+    "### 1.2 Hardware\n",
+    "This notebook can be executed on any CUDA-enabled NVIDIA GPU, although for efficient mixed precision inference, a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) is desired (Volta, Turing or newer architectures). "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "k7RLEcKhmTb0"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Thu Mar  5 00:28:21 2020       \r\n",
+      "+-----------------------------------------------------------------------------+\r\n",
+      "| NVIDIA-SMI 440.48.02    Driver Version: 440.48.02    CUDA Version: 10.2     |\r\n",
+      "|-------------------------------+----------------------+----------------------+\r\n",
+      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
+      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\r\n",
+      "|===============================+======================+======================|\r\n",
+      "|   0  Quadro GV100        Off  | 00000000:05:00.0 Off |                  Off |\r\n",
+      "| 32%   42C    P2    28W / 250W |  17706MiB / 32506MiB |      3%      Default |\r\n",
+      "+-------------------------------+----------------------+----------------------+\r\n",
+      "                                                                               \r\n",
+      "+-----------------------------------------------------------------------------+\r\n",
+      "| Processes:                                                       GPU Memory |\r\n",
+      "|  GPU       PID   Type   Process name                             Usage      |\r\n",
+      "|=============================================================================|\r\n",
+      "+-----------------------------------------------------------------------------+\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!nvidia-smi"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "EQAIszkxmTcT"
+   },
+   "source": [
+    "This notebook also requires access to a microphone. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a id=\"2\"></a>\n",
+    "## 2 Setup \n",
+    "### Import libraries and parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import argparse\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import sys\n",
+    "from builtins import range\n",
+    "from functools import partial\n",
+    "import soundfile\n",
+    "import pyaudio as pa\n",
+    "import soundfile\n",
+    "import librosa\n",
+    "\n",
+    "import grpc\n",
+    "from tensorrtserver.api import api_pb2\n",
+    "from tensorrtserver.api import grpc_service_pb2\n",
+    "from tensorrtserver.api import grpc_service_pb2_grpc\n",
+    "import tensorrtserver.api.model_config_pb2 as model_config"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "parser = argparse.ArgumentParser()\n",
+    "parser.add_argument('-f', '--file', help='Path for input file. First line should contain number of lines to search in')\n",
+    "\n",
+    "parser.add_argument('-v', '--verbose', action=\"store_true\", required=False, default=False,\n",
+    "                    help='Enable verbose output')\n",
+    "parser.add_argument('-a', '--async', dest=\"async_set\", action=\"store_true\", required=False,\n",
+    "                    default=False, help='Use asynchronous inference API')\n",
+    "parser.add_argument('--streaming', action=\"store_true\", required=False, default=False,\n",
+    "                    help='Use streaming inference API')\n",
+    "parser.add_argument('-m', '--model-name', type=str, required=False, default='kaldi_online' ,\n",
+    "                    help='Name of model')\n",
+    "parser.add_argument('-x', '--model-version', type=int, required=False, default=1,\n",
+    "                    help='Version of model. Default is to use latest version.')\n",
+    "parser.add_argument('-b', '--batch-size', type=int, required=False, default=1,\n",
+    "                    help='Batch size. Default is 1.')\n",
+    "parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',\n",
+    "                    help='Inference server URL. Default is localhost:8001.')\n",
+    "parser.add_argument('--chunk_duration', type=float, required=False,\n",
+    "                    default=0.51,\n",
+    "                    help=\"duration of the audio chunk for streaming \"\n",
+    "                            \"recognition, in seconds\")\n",
+    "parser.add_argument('--input_device_id', type=int, required=False,\n",
+    "                    default=-1, help='Input device id to use to capture audio')\n",
+    "parser.add_argument('--sample_rate', type=int, required=False,\n",
+    "                    default=16000, help='Sample rate.')\n",
+    "FLAGS = parser.parse_args()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Checking server status\n",
+    "\n",
+    "We first query the status of the server. The target model is 'kaldi_online'. A successful deployment of the Kaldi TRTIS server should result in output similar to the below.\n",
+    "\n",
+    "```\n",
+    "request_status {\n",
+    "  code: SUCCESS\n",
+    "  server_id: \"inference:0\"\n",
+    "  request_id: 17514\n",
+    "}\n",
+    "server_status {\n",
+    "  id: \"inference:0\"\n",
+    "  version: \"1.9.0\"\n",
+    "  uptime_ns: 14179155408971\n",
+    "  model_status {\n",
+    "    key: \"kaldi_online\"\n",
+    "...\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "request_status {\n",
+      "  code: SUCCESS\n",
+      "  server_id: \"inference:0\"\n",
+      "  request_id: 6234\n",
+      "}\n",
+      "server_status {\n",
+      "  id: \"inference:0\"\n",
+      "  version: \"1.9.0\"\n",
+      "  uptime_ns: 4061941924008\n",
+      "  model_status {\n",
+      "    key: \"kaldi_online\"\n",
+      "    value {\n",
+      "      config {\n",
+      "        name: \"kaldi_online\"\n",
+      "        platform: \"custom\"\n",
+      "        version_policy {\n",
+      "          latest {\n",
+      "            num_versions: 1\n",
+      "          }\n",
+      "        }\n",
+      "        max_batch_size: 2200\n",
+      "        input {\n",
+      "          name: \"WAV_DATA\"\n",
+      "          data_type: TYPE_FP32\n",
+      "          dims: 8160\n",
+      "        }\n",
+      "        input {\n",
+      "          name: \"WAV_DATA_DIM\"\n",
+      "          data_type: TYPE_INT32\n",
+      "          dims: 1\n",
+      "        }\n",
+      "        output {\n",
+      "          name: \"TEXT\"\n",
+      "          data_type: TYPE_STRING\n",
+      "          dims: 1\n",
+      "        }\n",
+      "        instance_group {\n",
+      "          name: \"kaldi_online_0\"\n",
+      "          count: 2\n",
+      "          gpus: 0\n",
+      "          kind: KIND_GPU\n",
+      "        }\n",
+      "        default_model_filename: \"libkaldi-trtisbackend.so\"\n",
+      "        sequence_batching {\n",
+      "          max_sequence_idle_microseconds: 5000000\n",
+      "          control_input {\n",
+      "            name: \"START\"\n",
+      "            control {\n",
+      "              int32_false_true: 0\n",
+      "              int32_false_true: 1\n",
+      "            }\n",
+      "          }\n",
+      "          control_input {\n",
+      "            name: \"READY\"\n",
+      "            control {\n",
+      "              kind: CONTROL_SEQUENCE_READY\n",
+      "              int32_false_true: 0\n",
+      "              int32_false_true: 1\n",
+      "            }\n",
+      "          }\n",
+      "          control_input {\n",
+      "            name: \"END\"\n",
+      "            control {\n",
+      "              kind: CONTROL_SEQUENCE_END\n",
+      "              int32_false_true: 0\n",
+      "              int32_false_true: 1\n",
+      "            }\n",
+      "          }\n",
+      "          control_input {\n",
+      "            name: \"CORRID\"\n",
+      "            control {\n",
+      "              kind: CONTROL_SEQUENCE_CORRID\n",
+      "              data_type: TYPE_UINT64\n",
+      "            }\n",
+      "          }\n",
+      "          oldest {\n",
+      "            max_candidate_sequences: 2200\n",
+      "            preferred_batch_size: 256\n",
+      "            preferred_batch_size: 512\n",
+      "            max_queue_delay_microseconds: 1000\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"acoustic_scale\"\n",
+      "          value {\n",
+      "            string_value: \"1.0\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"beam\"\n",
+      "          value {\n",
+      "            string_value: \"10\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"frame_subsampling_factor\"\n",
+      "          value {\n",
+      "            string_value: \"3\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"fst_rxfilename\"\n",
+      "          value {\n",
+      "            string_value: \"/data/models/LibriSpeech/HCLG.fst\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"ivector_filename\"\n",
+      "          value {\n",
+      "            string_value: \"/data/models/LibriSpeech/conf/ivector_extractor.conf\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"lattice_beam\"\n",
+      "          value {\n",
+      "            string_value: \"7\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"max_active\"\n",
+      "          value {\n",
+      "            string_value: \"10000\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"max_execution_batch_size\"\n",
+      "          value {\n",
+      "            string_value: \"512\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"mfcc_filename\"\n",
+      "          value {\n",
+      "            string_value: \"/data/models/LibriSpeech/conf/mfcc.conf\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"nnet3_rxfilename\"\n",
+      "          value {\n",
+      "            string_value: \"/data/models/LibriSpeech/final.mdl\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"num_worker_threads\"\n",
+      "          value {\n",
+      "            string_value: \"40\"\n",
+      "          }\n",
+      "        }\n",
+      "        parameters {\n",
+      "          key: \"word_syms_rxfilename\"\n",
+      "          value {\n",
+      "            string_value: \"/data/models/LibriSpeech/words.txt\"\n",
+      "          }\n",
+      "        }\n",
+      "      }\n",
+      "      version_status {\n",
+      "        key: 1\n",
+      "        value {\n",
+      "          ready_state: MODEL_READY\n",
+      "          infer_stats {\n",
+      "            key: 1\n",
+      "            value {\n",
+      "              success {\n",
+      "                count: 6913\n",
+      "                total_time_ns: 233146745257\n",
+      "              }\n",
+      "              compute {\n",
+      "                count: 6913\n",
+      "                total_time_ns: 225589026013\n",
+      "              }\n",
+      "              queue {\n",
+      "                count: 6913\n",
+      "                total_time_ns: 7398387984\n",
+      "              }\n",
+      "            }\n",
+      "          }\n",
+      "          model_execution_count: 6913\n",
+      "          model_inference_count: 6913\n",
+      "          ready_state_reason {\n",
+      "          }\n",
+      "          last_inference_timestamp_milliseconds: 13619175935932035456\n",
+      "        }\n",
+      "      }\n",
+      "    }\n",
+      "  }\n",
+      "  ready_state: SERVER_READY\n",
+      "}\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Create gRPC stub for communicating with the server\n",
+    "channel = grpc.insecure_channel(FLAGS.url)\n",
+    "grpc_stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)\n",
+    "\n",
+    "# Prepare request for Status gRPC\n",
+    "request = grpc_service_pb2.StatusRequest(model_name=FLAGS.model_name)\n",
+    "# Call and receive response from Status gRPC\n",
+    "response = grpc_stub.Status(request)\n",
+    "\n",
+    "print(response)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Testing microphone\n",
+    "\n",
+    "We next identify the input devices in the system. You will need to select a relevant input device amongst the ones listed. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Input Devices:\n",
+      "0: HDA Intel PCH: ALC1150 Analog (hw:0,0)\n",
+      "1: HDA Intel PCH: ALC1150 Digital (hw:0,1)\n",
+      "2: HDA Intel PCH: ALC1150 Alt Analog (hw:0,2)\n",
+      "3: HD Pro Webcam C920: USB Audio (hw:1,0)\n",
+      "4: HDA NVidia: HDMI 0 (hw:2,3)\n",
+      "5: HDA NVidia: HDMI 2 (hw:2,8)\n",
+      "6: HDA NVidia: HDMI 3 (hw:2,9)\n",
+      "7: sysdefault\n",
+      "8: front\n",
+      "9: surround21\n",
+      "10: surround40\n",
+      "11: surround41\n",
+      "12: surround50\n",
+      "13: surround51\n",
+      "14: surround71\n",
+      "15: iec958\n",
+      "16: spdif\n",
+      "17: default\n",
+      "18: dmix\n",
+      "Enter device id to use: 3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pyaudio\n",
+    "import wave\n",
+    "\n",
+    "p = pyaudio.PyAudio()  # Create an interface to PortAudio\n",
+    "\n",
+    "device_info = p.get_host_api_info_by_index(0)\n",
+    "num_devices = device_info.get('deviceCount')\n",
+    "\n",
+    "devices = {}\n",
+    "for i in range(0, num_devices):\n",
+    "    #if (p.get_device_info_by_host_api_device_index(0, i).get(\n",
+    "    #    'maxInputChannels')) > 0:\n",
+    "        devices[i] = p.get_device_info_by_host_api_device_index(\n",
+    "            0, i)\n",
+    "\n",
+    "if (len(devices) == 0):\n",
+    "    raise RuntimeError(\"Cannot find any valid input devices\")\n",
+    "\n",
+    "\n",
+    "print(\"\\nInput Devices:\")\n",
+    "for id, info in devices.items():\n",
+    "    print(\"{}: {}\".format(id,info.get(\"name\")))\n",
+    "input_device_id = int(input(\"Enter device id to use: \"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We then employ the selected device, record from it and play back to verify that everything is in order."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Device info:\n",
+      "{   'defaultHighInputLatency': 0.048,\n",
+      "    'defaultHighOutputLatency': -1.0,\n",
+      "    'defaultLowInputLatency': 0.01196875,\n",
+      "    'defaultLowOutputLatency': -1.0,\n",
+      "    'defaultSampleRate': 32000.0,\n",
+      "    'hostApi': 0,\n",
+      "    'index': 3,\n",
+      "    'maxInputChannels': 2,\n",
+      "    'maxOutputChannels': 0,\n",
+      "    'name': 'HD Pro Webcam C920: USB Audio (hw:1,0)',\n",
+      "    'structVersion': 2}\n",
+      "Recording\n",
+      "Finished recording\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pprint\n",
+    "pp = pprint.PrettyPrinter(indent=4)\n",
+    "    \n",
+    "print(\"Device info:\")\n",
+    "devinfo = p.get_device_info_by_index(input_device_id)  # Or whatever device you care about.\n",
+    "pp.pprint(devinfo)\n",
+    "\n",
+    "chunk = 1024  # Record in chunks of 1024 samples\n",
+    "sample_format = pyaudio.paInt16  # 16 bits per sample\n",
+    "channels = 1\n",
+    "fs = devinfo['defaultSampleRate']  # Record at device default sampling rate\n",
+    "seconds = 3\n",
+    "filename = \"test.wav\"\n",
+    "\n",
+    "print('Recording')\n",
+    "\n",
+    "stream = p.open(format=sample_format,\n",
+    "                channels=channels,\n",
+    "                rate=int(devinfo[\"defaultSampleRate\"]),\n",
+    "                frames_per_buffer=chunk,\n",
+    "                input=True,\n",
+    "                input_device_index=input_device_id)\n",
+    "\n",
+    "frames = []  # Initialize array to store frames\n",
+    "\n",
+    "# Store data in chunks for 3 seconds\n",
+    "for i in range(0, int(fs / chunk * seconds)):\n",
+    "    data = stream.read(chunk)\n",
+    "    frames.append(data)\n",
+    "\n",
+    "# Stop and close the stream \n",
+    "stream.stop_stream()\n",
+    "stream.close()\n",
+    "# Terminate the PortAudio interface\n",
+    "# p.terminate()\n",
+    "\n",
+    "print('Finished recording')\n",
+    "\n",
+    "# Save the recorded data as a WAV file\n",
+    "wf = wave.open(filename, 'wb')\n",
+    "wf.setnchannels(channels)\n",
+    "wf.setsampwidth(p.get_sample_size(sample_format))\n",
+    "wf.setframerate(fs)\n",
+    "wf.writeframes(b''.join(frames))\n",
+    "wf.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import IPython.display as ipd\n",
+    "ipd.Audio(filename)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "RL8d9IwzmTcV"
+   },
+   "source": [
+    "<a id=\"3\"></a>\n",
+    "## 3. Audio helper classes"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "o6wayGf1mTcX"
+   },
+   "source": [
+    "Next, we define some helper classes for pre-processing audio. The below AudioSegment class takes audio signal and converts the sampling rate to that required by the Kaldi ASR model, which is 16000Hz by default.\n",
+    "\n",
+    "Note:  For historical reasons, Kaldi expects waveforms in the range (2^15-1)x[-1, 1], not the usual default DSP range [-1, 1]. Therefore, we scale the audio signal by a factor of (2^15-1)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "WAV_SCALE_FACTOR = 2**15-1\n",
+    "\n",
+    "class AudioSegment(object):\n",
+    "    \"\"\"Monaural audio segment abstraction.\n",
+    "    :param samples: Audio samples [num_samples x num_channels].\n",
+    "    :type samples: ndarray.float32\n",
+    "    :param sample_rate: Audio sample rate.\n",
+    "    :type sample_rate: int\n",
+    "    :raises TypeError: If the sample data type is not float or int.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, samples, sample_rate, target_sr=16000, trim=False,\n",
+    "                 trim_db=60):\n",
+    "        \"\"\"Create audio segment from samples.\n",
+    "        Samples are convert float32 internally, with int scaled to [-1, 1].\n",
+    "        \"\"\"\n",
+    "        samples = self._convert_samples_to_float32(samples)\n",
+    "        if target_sr is not None and target_sr != sample_rate:\n",
+    "            samples = librosa.core.resample(samples, sample_rate, target_sr)\n",
+    "            sample_rate = target_sr\n",
+    "        if trim:\n",
+    "            samples, _ = librosa.effects.trim(samples, trim_db)\n",
+    "        self._samples = samples\n",
+    "        self._sample_rate = sample_rate\n",
+    "        if self._samples.ndim >= 2:\n",
+    "            self._samples = np.mean(self._samples, 1)\n",
+    "\n",
+    "    @staticmethod\n",
+    "    def _convert_samples_to_float32(samples):\n",
+    "        \"\"\"Convert sample type to float32.\n",
+    "        Audio sample type is usually integer or float-point.\n",
+    "        Integers will be scaled to [-1, 1] in float32.\n",
+    "        \"\"\"\n",
+    "        float32_samples = samples.astype('float32')\n",
+    "        if samples.dtype in np.sctypes['int']:\n",
+    "            bits = np.iinfo(samples.dtype).bits\n",
+    "            float32_samples *= (1. / ((2 ** (bits - 1)) - 1))\n",
+    "        elif samples.dtype in np.sctypes['float']:\n",
+    "            pass\n",
+    "        else:\n",
+    "            raise TypeError(\"Unsupported sample type: %s.\" % samples.dtype)\n",
+    "        return WAV_SCALE_FACTOR * float32_samples\n",
+    "\n",
+    "    @classmethod\n",
+    "    def from_file(cls, filename, target_sr=16000, offset=0, duration=0,\n",
+    "                 min_duration=0, trim=False):\n",
+    "        \"\"\"\n",
+    "        Load a file supported by librosa and return as an AudioSegment.\n",
+    "        :param filename: path of file to load\n",
+    "        :param target_sr: the desired sample rate\n",
+    "        :param int_values: if true, load samples as 32-bit integers\n",
+    "        :param offset: offset in seconds when loading audio\n",
+    "        :param duration: duration in seconds when loading audio\n",
+    "        :return: numpy array of samples\n",
+    "        \"\"\"\n",
+    "        with sf.SoundFile(filename, 'r') as f:\n",
+    "            dtype_options = {'PCM_16': 'int16', 'PCM_32': 'int32', 'FLOAT': 'float32'}\n",
+    "            dtype_file = f.subtype\n",
+    "            if dtype_file in dtype_options:\n",
+    "                dtype = dtype_options[dtype_file]\n",
+    "            else:\n",
+    "                dtype = 'float32'\n",
+    "            sample_rate = f.samplerate\n",
+    "            if offset > 0:\n",
+    "                f.seek(int(offset * sample_rate))\n",
+    "            if duration > 0:\n",
+    "                samples = f.read(int(duration * sample_rate), dtype=dtype)\n",
+    "            else:\n",
+    "                samples = f.read(dtype=dtype)\n",
+    "\n",
+    "        num_zero_pad = int(target_sr * min_duration - samples.shape[0])\n",
+    "        if num_zero_pad > 0:\n",
+    "            samples = np.pad(samples, [0, num_zero_pad], mode='constant')\n",
+    "\n",
+    "        samples = samples.transpose()\n",
+    "        return cls(samples, sample_rate, target_sr=target_sr, trim=trim)\n",
+    "\n",
+    "    @property\n",
+    "    def samples(self):\n",
+    "        return self._samples.copy()\n",
+    "\n",
+    "    @property\n",
+    "    def sample_rate(self):\n",
+    "        return self._sample_rate"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a id=\"4\"></a>\n",
+    "## Inference\n",
+    "\n",
+    "We first create an inference context object that connects to the Kaldi TRTIS servier via a gPRC connection.\n",
+    "\n",
+    "The server expects chunks of audio each containing up to input.WAV_DATA.dims samples (default: 8160). Per default, this corresponds to 510ms of audio per chunk (i.e. 16000Hz sampling rate). The last chunk can send a partial chunk smaller than this maximum value."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tensorrtserver.api import *\n",
+    "protocol = ProtocolType.from_str(\"grpc\")\n",
+    "\n",
+    "CORRELATION_ID = 11101\n",
+    "ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,\n",
+    "                    correlation_id=CORRELATION_ID, verbose=True,\n",
+    "                    streaming=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we take chunks of audio (each 510ms in duration, containing 8160 samples) from the microphone and stream them sequentially to the Kaldi server. The server processes each chunk as soon as it is received. \n",
+    "\n",
+    "Unlike data from a .wav file, as we take the data continuoulsy from the mic, there is no `end` marker. Therefore, we receive the result once every 10 chunks. Note that the server will reset it status once the result is sent out.   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class TranscribeFromMicrophone:\n",
+    "\n",
+    "    def __init__(self,input_device_id, target_sr, chunk_duration):\n",
+    "\n",
+    "        self.recording_state = \"init\"\n",
+    "        self.target_sr  = target_sr\n",
+    "        self.chunk_duration = chunk_duration\n",
+    "\n",
+    "        self.p = pa.PyAudio()\n",
+    "\n",
+    "        device_info = self.p.get_host_api_info_by_index(0)\n",
+    "        num_devices = device_info.get('deviceCount')\n",
+    "        devices = {}\n",
+    "        for i in range(0, num_devices):\n",
+    "            if (self.p.get_device_info_by_host_api_device_index(0, i).get(\n",
+    "                'maxInputChannels')) > 0:\n",
+    "                devices[i] = self.p.get_device_info_by_host_api_device_index(\n",
+    "                    0, i)\n",
+    "\n",
+    "        if (len(devices) == 0):\n",
+    "            raise RuntimeError(\"Cannot find any valid input devices\")\n",
+    "\n",
+    "        if input_device_id is None or input_device_id not in \\\n",
+    "            devices.keys():\n",
+    "            print(\"\\nInput Devices:\")\n",
+    "            for id, info in devices.items():\n",
+    "                print(\"{}: {}\".format(id,info.get(\"name\")))\n",
+    "            input_device_id = int(input(\"Enter device id to use: \"))\n",
+    "\n",
+    "        self.input_device_id = input_device_id\n",
+    "        devinfo = self.p.get_device_info_by_index(input_device_id)\n",
+    "        self.device_default_sr = int(devinfo['defaultSampleRate'])\n",
+    "        print(\"Device sample rate: %d\" % self.device_default_sr)\n",
+    "\n",
+    "    def transcribe_audio(self, streaming=True):\n",
+    "        ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,\n",
+    "                    correlation_id=CORRELATION_ID, verbose=True,\n",
+    "                    streaming=False)\n",
+    "        \n",
+    "        chunk_size = int(self.chunk_duration*self.device_default_sr)\n",
+    "        self.recording_state = \"init\"\n",
+    "\n",
+    "        def keyboard_listener():\n",
+    "            input(\"**********Press Enter to start and end transcribing...**********\")\n",
+    "            self.recording_state = \"capture\"\n",
+    "            print(\"Recording...\")\n",
+    "            \n",
+    "            input(\"\")\n",
+    "            self.recording_state = \"release\"\n",
+    "\n",
+    "        listener = threading.Thread(target=keyboard_listener)\n",
+    "        listener.start()\n",
+    "\n",
+    "        start = True\n",
+    "        print(\"starting....\")\n",
+    "        \n",
+    "        stream_initialized = False\n",
+    "        audio_signal = 0\n",
+    "        audio_segment = 0\n",
+    "        end = False\n",
+    "        \n",
+    "        cnt = 0\n",
+    "        MAX_CHUNKS = 10\n",
+    "        while self.recording_state != \"release\":\n",
+    "            try:\n",
+    "                if self.recording_state == \"capture\":\n",
+    "\n",
+    "                    if not stream_initialized:\n",
+    "                        stream = self.p.open(\n",
+    "                            format=pa.paInt16,\n",
+    "                            channels=1,\n",
+    "                            rate=self.device_default_sr,\n",
+    "                            input=True,\n",
+    "                            input_device_index=self.input_device_id,\n",
+    "                            frames_per_buffer=chunk_size)\n",
+    "                        stream_initialized = True\n",
+    "\n",
+    "                    # Read an audio chunk from microphone\n",
+    "                    audio_signal = stream.read(chunk_size, exception_on_overflow = False)\n",
+    "                    if self.recording_state == \"release\":\n",
+    "                      break\n",
+    "                      end = True\n",
+    "                    audio_signal = np.frombuffer(audio_signal,dtype=np.int16)\n",
+    "                    audio_segment = AudioSegment(audio_signal,\n",
+    "                                                              self.device_default_sr,\n",
+    "                                                              self.target_sr)\n",
+    "                    \n",
+    "                    if cnt == MAX_CHUNKS:\n",
+    "                        end = True\n",
+    "                    if cnt > 1:\n",
+    "                        start = False\n",
+    "                        \n",
+    "                    # Inference\n",
+    "                    flags = InferRequestHeader.FLAG_NONE\n",
+    "                    x = (audio_segment.samples, self.target_sr, start, end)\n",
+    "                    if x[2]:\n",
+    "                        flags = flags | InferRequestHeader.FLAG_SEQUENCE_START\n",
+    "                    if x[3]:\n",
+    "                        flags = flags | InferRequestHeader.FLAG_SEQUENCE_END\n",
+    "                    if not end:\n",
+    "                        ctx.run({'WAV_DATA' : (x[0],),\n",
+    "                                 'WAV_DATA_DIM' : (np.full(shape=1, fill_value=len(x[0]), dtype=np.int32),)},\n",
+    "                                {},\n",
+    "                                batch_size=1,\n",
+    "                                flags=flags,\n",
+    "                                corr_id=CORRELATION_ID)\n",
+    "                    else:\n",
+    "                        res = ctx.run({'WAV_DATA' : (x[0],),\n",
+    "                                       'WAV_DATA_DIM' : (np.full(shape=1, fill_value=len(x[0]), dtype=np.int32),)},\n",
+    "                                      { 'TEXT' : InferContext.ResultFormat.RAW },\n",
+    "                                      batch_size=1,\n",
+    "                                      flags=flags,\n",
+    "                                      corr_id=CORRELATION_ID)\n",
+    "                        print(\"\".join([x.decode('utf-8') for x in res['TEXT'][0]]))\n",
+    "                    \n",
+    "                    if cnt == MAX_CHUNKS: # reset server\n",
+    "                        start = True\n",
+    "                        end = False\n",
+    "                        cnt = 0\n",
+    "                    \n",
+    "                    cnt += 1\n",
+    "                    sys.stdout.write(\"\\r\" + \".\"*cnt)\n",
+    "                    sys.stdout.flush()\n",
+    "                    \n",
+    "            except Exception as e:\n",
+    "                print(e)\n",
+    "                break\n",
+    "\n",
+    "        stream.close()\n",
+    "        self.p.terminate()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Device sample rate: 32000\n"
+     ]
+    }
+   ],
+   "source": [
+    "transcriber = TranscribeFromMicrophone(input_device_id,\n",
+    "    target_sr=FLAGS.sample_rate,\n",
+    "    chunk_duration=FLAGS.chunk_duration)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After executing the below cell, upon pressing ENTER, the mic will start recording chunks of audio from the specified mic and stream them continuously to the server. After every 10 chunks, the client takes and display the results, while the status of the server is reset, i.e., it treats the next chunk as the start of a fresh new request. \n",
+    "When pressing ENTER again, the client stops.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "transcriber.transcribe_audio()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "g8MxXY5GmTc8"
+   },
+   "source": [
+    "# Conclusion\n",
+    "\n",
+    "In this notebook, we have walked through the complete process of preparing the audio data from a microphone and carry out inference with the Kaldi ASR model.\n",
+    "\n",
+    "## What's next\n",
+    "Now it's time to try the Kaldi ASR model on your own data. The online client can also be further improved, for example, by detecting natural breaks in the input stream (e.g., silence) to break sentence more properly. \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "249yGNLmmTc_"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "include_colab_link": true,
+   "name": "TensorFlow_UNet_Industrial_Colab_train_and_inference.ipynb",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}

+ 75 - 0
Kaldi/SpeechRecognition/notebooks/README.md

@@ -0,0 +1,75 @@
+```
+# Licensed under the Apache License, Version 2.0 (the "License")
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and limitations under the License.
+```
+<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">
+
+# Kaldi  inference demo
+
+## 1. Overview
+
+This folder contains two notebooks demonstrating the steps for carrying out inferencing with the Kaldi TRTIS backend server using a Python gRPC client.
+ 
+- [Offline](Kaldi_TRTIS_inference_offline_demo.ipynb): we will stream pre-recorded .wav files to the inference server and receive the results back.
+- [Online](Kaldi_TRTIS_inference_online_demo.ipynb): we will stream live audio stream from a microphone to the inference server and receive the results back.
+
+## 2. Quick Start Guide
+
+First, clone the repository:
+
+```
+git clone https://github.com/NVIDIA/DeepLearningExamples.git
+cd DeepLearningExamples/Kaldi/SpeechRecognition
+```
+Next, build the NVIDIA Kaldi TRTIS container:
+
+```
+scripts/docker/build.sh
+```
+
+Then download the model and some test data set with:
+```
+scripts/docker/launch_download.sh
+```
+Next, launch the TRTIS container with:
+```
+scripts/docker/launch_server.sh
+```
+After this step, we should have a TRTIS server ready to serve ASR inference requests.
+
+The next step is to build a TRTIS client container:
+
+```bash
+docker build -t kaldi_notebook_client -f Dockerfile.notebook .
+```
+
+Start the client container with:
+
+```bash
+docker run -it --rm --net=host --device /dev/snd:/dev/snd -v $PWD:/Kaldi kaldi_notebook_client
+```
+
+Within the client container, start Jupyter notebook server:
+
+```bash
+cd /Kaldi
+jupyter notebook --ip=0.0.0.0 --allow-root
+```
+
+And navigate a web browser to the IP address or hostname of the host machine
+at port `8888`:
+
+```
+http://[host machine]:8888
+```
+
+Use the token listed in the output from running the `jupyter` command to log
+in, for example:
+
+```
+http://[host machine]:8888/?token=aae96ae9387cd28151868fee318c3b3581a2d794f3b25c6b
+```

+ 24 - 0
Kaldi/SpeechRecognition/scripts/compute_wer.sh

@@ -0,0 +1,24 @@
+#!/bin/bash
+
+model_path="/data/models/LibriSpeech"
+librispeech_path="/data/datasets/LibriSpeech/test_clean"
+result_path="/data/results"
+
+# Correctness
+
+cat $model_path/words.txt | tr '[:upper:]' '[:lower:]' > $result_path/words.txt
+cat $librispeech_path/$test_set/text | tr '[:upper:]' '[:lower:]' > $result_path/text
+oovtok=$(cat $result_path/words.txt | grep "<unk>" | awk '{print $2}')
+/opt/kaldi/egs/wsj/s5/utils/sym2int.pl --map-oov $oovtok -f 2- $result_path/words.txt $result_path/text > $result_path/text_ints 2> /dev/null
+
+
+# convert lattice to transcript
+/opt/kaldi/src/latbin/lattice-best-path \
+	"ark:gunzip -c $result_path/lat.cuda-asr.gz |"\
+	"ark,t:|gzip -c > $result_path/trans.cuda-asr.gz" 2> /dev/null
+
+# calculate wer
+/opt/kaldi/src/bin/compute-wer --mode=present \
+	"ark:$result_path/text_ints" \
+	"ark:gunzip -c $result_path/trans.cuda-asr.gz |" 2>&1
+

+ 1 - 1
Kaldi/SpeechRecognition/scripts/docker/launch_client.sh

@@ -19,4 +19,4 @@ docker run --rm -it \
     --ulimit memlock=-1 \
     --ulimit memlock=-1 \
     --ulimit stack=67108864 \
     --ulimit stack=67108864 \
     -v $PWD/data:/data \
     -v $PWD/data:/data \
-    trtis_kaldi_client install/bin/kaldi_asr_parallel_client $@
+    trtis_kaldi_client /workspace/scripts/docker/run_client.sh $@

+ 13 - 0
Kaldi/SpeechRecognition/scripts/docker/run_client.sh

@@ -0,0 +1,13 @@
+#!/bin/bash
+set -e 
+results_dir=/data/results
+
+if [ -d "$results_dir" ]
+then
+	rm -rf $results_dir
+fi	
+mkdir $results_dir
+install/bin/kaldi_asr_parallel_client $@
+echo "Computing WER..."
+/workspace/scripts/compute_wer.sh
+rm -rf $results_dir

+ 1 - 1
Kaldi/SpeechRecognition/scripts/run_inference_all_t4.sh

@@ -16,7 +16,7 @@
 set -e
 set -e
 
 
 if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
 if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
-	printf "\nThe TensorRT Inference Server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
+	printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
 	exit 1
 	exit 1
 fi
 fi
 
 

+ 1 - 1
Kaldi/SpeechRecognition/scripts/run_inference_all_v100.sh

@@ -16,7 +16,7 @@
 set -e
 set -e
 
 
 if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
 if [[ "$(docker ps | grep trtis_kaldi_server | wc -l)" == "0" ]]; then
-	printf "\nThe TensorRT Inference Server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
+	printf "\nThe Triton server is currently not running. Please run scripts/docker/launch_server.sh\n\n"
 	exit 1
 	exit 1
 fi
 fi
 
 

+ 67 - 45
Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.cc

@@ -104,6 +104,7 @@ int Context::ReadModelParameters() {
 int Context::InitializeKaldiPipeline() {
 int Context::InitializeKaldiPipeline() {
   batch_corr_ids_.reserve(max_batch_size_);
   batch_corr_ids_.reserve(max_batch_size_);
   batch_wave_samples_.reserve(max_batch_size_);
   batch_wave_samples_.reserve(max_batch_size_);
+  batch_is_first_chunk_.reserve(max_batch_size_);
   batch_is_last_chunk_.reserve(max_batch_size_);
   batch_is_last_chunk_.reserve(max_batch_size_);
   wave_byte_buffers_.resize(max_batch_size_);
   wave_byte_buffers_.resize(max_batch_size_);
   output_shape_ = {1, 1};
   output_shape_ = {1, 1};
@@ -181,10 +182,17 @@ int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads,
 
 
     kaldi::SubVector<BaseFloat> wave_part(wave_buffer, dim);
     kaldi::SubVector<BaseFloat> wave_part(wave_buffer, dim);
     // Initialize corr_id if first chunk
     // Initialize corr_id if first chunk
-    if (start) cuda_pipeline_->InitCorrID(corr_id);
+    if (start) {
+      if (!cuda_pipeline_->TryInitCorrID(corr_id)) {
+        printf("ERR %i \n", __LINE__);
+        // TODO add error code
+        continue;
+      }
+    }
     // Add to batch
     // Add to batch
     batch_corr_ids_.push_back(corr_id);
     batch_corr_ids_.push_back(corr_id);
     batch_wave_samples_.push_back(wave_part);
     batch_wave_samples_.push_back(wave_part);
+    batch_is_first_chunk_.push_back(start);
     batch_is_last_chunk_.push_back(end);
     batch_is_last_chunk_.push_back(end);
 
 
     if (end) {
     if (end) {
@@ -192,9 +200,7 @@ int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads,
       cuda_pipeline_->SetLatticeCallback(
       cuda_pipeline_->SetLatticeCallback(
           corr_id, [this, &output_fn, &payloads, pidx,
           corr_id, [this, &output_fn, &payloads, pidx,
                     corr_id](kaldi::CompactLattice& clat) {
                     corr_id](kaldi::CompactLattice& clat) {
-            std::string output;
-            LatticeToString(*word_syms_, clat, &output);
-            SetOutputTensor(output, output_fn, payloads[pidx]);
+            SetOutputs(clat, output_fn, payloads[pidx]);
           });
           });
     }
     }
   }
   }
@@ -206,9 +212,10 @@ int Context::Execute(const uint32_t payload_cnt, CustomPayload* payloads,
 int Context::FlushBatch() {
 int Context::FlushBatch() {
   if (!batch_corr_ids_.empty()) {
   if (!batch_corr_ids_.empty()) {
     cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_,
     cuda_pipeline_->DecodeBatch(batch_corr_ids_, batch_wave_samples_,
-                                batch_is_last_chunk_);
+                                batch_is_first_chunk_, batch_is_last_chunk_);
     batch_corr_ids_.clear();
     batch_corr_ids_.clear();
     batch_wave_samples_.clear();
     batch_wave_samples_.clear();
+    batch_is_first_chunk_.clear();
     batch_is_last_chunk_.clear();
     batch_is_last_chunk_.clear();
   }
   }
 }
 }
@@ -254,20 +261,21 @@ int Context::InputOutputSanityCheck() {
     return kInputName;
     return kInputName;
   }
   }
 
 
-  if (model_config_.output_size() != 1) {
-    return kInputOutput;
-  }
-  if ((model_config_.output(0).dims().size() != 1) ||
-      (model_config_.output(0).dims(0) != 1)) {
-    return kInputOutput;
-  }
-  if (model_config_.output(0).data_type() != DataType::TYPE_STRING) {
-    return kInputOutputDataType;
-  }
-  if (model_config_.output(0).name() != "TEXT") {
-    return kOutputName;
+  if (model_config_.output_size() != 2) return kInputOutput;
+
+  for (int ioutput = 0; ioutput < 2; ++ioutput) {
+    if ((model_config_.output(ioutput).dims().size() != 1) ||
+        (model_config_.output(ioutput).dims(0) != 1)) {
+      return kInputOutput;
+    }
+    if (model_config_.output(ioutput).data_type() != DataType::TYPE_STRING) {
+      return kInputOutputDataType;
+    }
   }
   }
 
 
+  if (model_config_.output(0).name() != "RAW_LATTICE") return kOutputName;
+  if (model_config_.output(1).name() != "TEXT") return kOutputName;
+
   return kSuccess;
   return kSuccess;
 }
 }
 
 
@@ -316,34 +324,48 @@ int Context::GetSequenceInput(CustomGetNextInputFn_t& input_fn,
   return kSuccess;
   return kSuccess;
 }
 }
 
 
-int Context::SetOutputTensor(const std::string& output,
-                             CustomGetOutputFn_t output_fn,
-                             CustomPayload payload) {
-  uint32_t byte_size_with_size_int = output.size() + sizeof(int32);
-
-  // std::cout << output << std::endl;
-
-  // copy output from best_path to output buffer
-  if ((payload.error_code == 0) && (payload.output_cnt > 0)) {
-    const char* output_name = payload.required_output_names[0];
-    // output buffer
-    void* obuffer;
-    if (!output_fn(payload.output_context, output_name, output_shape_.size(),
-                   &output_shape_[0], byte_size_with_size_int, &obuffer)) {
-      payload.error_code = kOutputBuffer;
-      return payload.error_code;
+int Context::SetOutputs(kaldi::CompactLattice& clat,
+                        CustomGetOutputFn_t output_fn, CustomPayload payload) {
+	int status = kSuccess;
+  if (payload.error_code != kSuccess) return payload.error_code;
+  for (int ioutput = 0; ioutput < payload.output_cnt; ++ioutput) {
+    const char* output_name = payload.required_output_names[ioutput];
+    if (!strcmp(output_name, "RAW_LATTICE")) {
+      std::ostringstream oss;
+      kaldi::WriteCompactLattice(oss, true, clat);
+      status = SetOutputByName(output_name, oss.str(), output_fn, payload);
+      if(status != kSuccess) return status;
+    } else if (!strcmp(output_name, "TEXT")) {
+      std::string output;
+      LatticeToString(*word_syms_, clat, &output);
+      status = SetOutputByName(output_name, output, output_fn, payload);
+      if(status != kSuccess) return status;
     }
     }
+  }
 
 
-    // If no error but the 'obuffer' is returned as nullptr, then
-    // skip writing this output.
-    if (obuffer != nullptr) {
-      // std::cout << "writing " << output << std::endl;
-      int32* buffer_as_int = reinterpret_cast<int32*>(obuffer);
-      buffer_as_int[0] = output.size();
-      memcpy(&buffer_as_int[1], output.data(), output.size());
-    }
+  return status;
+}
+
+int Context::SetOutputByName(const char* output_name,
+                             const std::string& out_bytes,
+                             CustomGetOutputFn_t output_fn,
+                             CustomPayload payload) {
+  uint32_t byte_size_with_size_int = out_bytes.size() + sizeof(int32);
+  void* obuffer;  // output buffer
+  if (!output_fn(payload.output_context, output_name, output_shape_.size(),
+                 &output_shape_[0], byte_size_with_size_int, &obuffer)) {
+    payload.error_code = kOutputBuffer;
+    return payload.error_code;
   }
   }
+  if (obuffer == nullptr) return kOutputBuffer;
+
+  int32* buffer_as_int = reinterpret_cast<int32*>(obuffer);
+  buffer_as_int[0] = out_bytes.size();
+  memcpy(&buffer_as_int[1], out_bytes.data(), out_bytes.size());
+
+  return kSuccess;
 }
 }
+
 /////////////
 /////////////
 
 
 extern "C" {
 extern "C" {
@@ -395,7 +417,7 @@ int CustomExecute(void* custom_context, const uint32_t payload_cnt,
 }
 }
 
 
 }  // extern "C"
 }  // extern "C"
-}
-}
-}
-}  // namespace nvidia::inferenceserver::custom::kaldi_cbe
+}  // namespace kaldi_cbe
+}  // namespace custom
+}  // namespace inferenceserver
+}  // namespace nvidia

+ 13 - 6
Kaldi/SpeechRecognition/trtis-kaldi-backend/kaldi-backend.h

@@ -20,6 +20,7 @@
 #include <sstream>
 #include <sstream>
 #include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
 #include "cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.h"
 #include "fstext/fstext-lib.h"
 #include "fstext/fstext-lib.h"
+#include "lat/kaldi-lattice.h"
 #include "lat/lattice-functions.h"
 #include "lat/lattice-functions.h"
 #include "nnet3/am-nnet-simple.h"
 #include "nnet3/am-nnet-simple.h"
 #include "nnet3/nnet-utils.h"
 #include "nnet3/nnet-utils.h"
@@ -62,8 +63,13 @@ class Context {
                        const kaldi::BaseFloat** wave_buffer,
                        const kaldi::BaseFloat** wave_buffer,
                        std::vector<uint8_t>* input_buffer);
                        std::vector<uint8_t>* input_buffer);
 
 
-  int SetOutputTensor(const std::string& output, CustomGetOutputFn_t output_fn,
-                      CustomPayload payload);
+  int SetOutputs(kaldi::CompactLattice& clat,
+                      CustomGetOutputFn_t output_fn, CustomPayload payload);
+
+  int SetOutputByName(const char* output_name,
+                             const std::string& out_bytes,
+                             CustomGetOutputFn_t output_fn,
+                             CustomPayload payload);
 
 
   bool CheckPayloadError(const CustomPayload& payload);
   bool CheckPayloadError(const CustomPayload& payload);
   int FlushBatch();
   int FlushBatch();
@@ -88,6 +94,7 @@ class Context {
   int num_worker_threads_;
   int num_worker_threads_;
   std::vector<CorrelationID> batch_corr_ids_;
   std::vector<CorrelationID> batch_corr_ids_;
   std::vector<kaldi::SubVector<kaldi::BaseFloat>> batch_wave_samples_;
   std::vector<kaldi::SubVector<kaldi::BaseFloat>> batch_wave_samples_;
+  std::vector<bool> batch_is_first_chunk_;
   std::vector<bool> batch_is_last_chunk_;
   std::vector<bool> batch_is_last_chunk_;
 
 
   BaseFloat sample_freq_, seconds_per_chunk_;
   BaseFloat sample_freq_, seconds_per_chunk_;
@@ -113,7 +120,7 @@ class Context {
   std::vector<std::vector<uint8_t>> wave_byte_buffers_;
   std::vector<std::vector<uint8_t>> wave_byte_buffers_;
 };
 };
 
 
-}  // kaldi
-}  // custom
-}  // inferenceserver
-}  // nvidia
+}  // namespace kaldi_cbe
+}  // namespace custom
+}  // namespace inferenceserver
+}  // namespace nvidia

Některé soubory nejsou zobrazeny, neboť je v těchto rozdílových datech změněno mnoho souborů