The WeNet runtime uses the Unified Two Pass (U2) framework for inference. U2 has the following advantages:
- Unified: U2 unifies streaming and non-streaming models in a simple way, and our runtime is unified accordingly, so you can easily balance latency and accuracy by changing chunk_size (described in the following section).
- Accurate: U2 achieves better accuracy through CTC joint training.
- Fast: Our runtime uses the attention-rescoring decoding method described in U2, which is much faster than traditional autoregressive beam search.
- Other benefits: In practice, we find U2 is more stable on long-form speech than a standard Transformer, which usually fails or degrades a lot on long-form speech, and U2 also makes it easy to obtain word-level timestamps from the CTC spikes. Both of these properties matter for industry adoption.
The WeNet runtime supports the following platforms.
The following picture shows how U2 works.
When the input is not finished, the input frames are fed into the Shared Encoder chunk by chunk, and CTC prefix beam search runs on the CTC activation of each chunk to produce streaming intermediate n-best results. When the input is finished, the Attention Decoder rescores the n-best CTC hypotheses over the whole encoder output to get the final result. We can group the model into three runtime modules accordingly: the Shared Encoder, the CTC Activation, and the Attention Decoder, which correspond to the exported interfaces listed below.
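To make the two passes concrete, here is a minimal sketch of the flow. The helper names (`encode_chunk`, `ctc_prefix_beam_search`, `attention_rescore`) are hypothetical placeholders, not WeNet APIs, and the search is re-run from scratch per chunk for brevity; the real streaming decoder is implemented in torch_asr_decoder.cc.

```python
from typing import Callable, List, Tuple

import torch

# (token id sequence, score) pair produced by the beam search.
Hypothesis = Tuple[List[int], float]


def u2_two_pass(
    chunks: List[torch.Tensor],
    encode_chunk: Callable[[torch.Tensor], torch.Tensor],
    ctc_prefix_beam_search: Callable[[torch.Tensor], List[Hypothesis]],
    attention_rescore: Callable[[List[Hypothesis], torch.Tensor], List[Hypothesis]],
) -> Hypothesis:
    """First pass: streaming CTC over chunks; second pass: attention rescoring."""
    encoder_outs: List[torch.Tensor] = []
    nbest: List[Hypothesis] = []
    for chunk in chunks:
        # First pass (streaming): the Shared Encoder consumes one chunk at a
        # time and the CTC prefix beam search updates the n-best, so partial
        # results are available immediately. (The real decoder updates the
        # search incrementally instead of re-running it on all frames.)
        encoder_outs.append(encode_chunk(chunk))
        nbest = ctc_prefix_beam_search(torch.cat(encoder_outs, dim=1))
    # Second pass (run once when the input is finished): the Attention Decoder
    # rescores the n-best CTC hypotheses over the whole encoder output.
    rescored = attention_rescore(nbest, torch.cat(encoder_outs, dim=1))
    return max(rescored, key=lambda hyp: hyp[1])
```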
We use LibTorch to implement the U2 runtime in WeNet, and we export several interfaces from the PyTorch Python code with @torch.jit.export (see asr_model.py); these interfaces are required by and used in the C++ runtime in torch_asr_model.cc and torch_asr_decoder.cc. Here we just list the interfaces and give a brief introduction.
interface | description |
---|---|
subsampling_rate (args) | get the subsampling rate of the model |
right_context (args) | get the right context of the model |
sos_symbol (args) | get the sos symbol id of the model |
eos_symbol (args) | get the eos symbol id of the model |
forward_encoder_chunk (args) | used for the Shared Encoder module |
ctc_activation (args) | used for the CTC Activation module |
forward_attention_decoder (args) | used for the Attention Decoder module |
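For reference, here is a minimal, self-contained sketch of how such methods are exported with @torch.jit.export so that the scripted model keeps them callable from the C++ runtime. `ToyASRModel` and all values are made-up placeholders; the real definitions and signatures are in asr_model.py.

```python
import torch


class ToyASRModel(torch.nn.Module):
    """Toy stand-in for the real ASR model, used only to show the export mechanism."""

    def __init__(self, subsampling_rate: int = 4, right_context: int = 6,
                 sos: int = 1, eos: int = 1):
        super().__init__()
        self._subsampling_rate = subsampling_rate
        self._right_context = right_context
        self._sos = sos
        self._eos = eos

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        # Placeholder forward; the real model runs the encoder here.
        return xs

    @torch.jit.export
    def subsampling_rate(self) -> int:
        # How many input frames map to one encoder output frame.
        return self._subsampling_rate

    @torch.jit.export
    def right_context(self) -> int:
        # Extra future frames required by the subsampling front-end.
        return self._right_context

    @torch.jit.export
    def sos_symbol(self) -> int:
        return self._sos

    @torch.jit.export
    def eos_symbol(self) -> int:
        return self._eos


# The scripted module keeps the exported methods callable from C++:
scripted = torch.jit.script(ToyASRModel())
print(scripted.subsampling_rate(), scripted.right_context())
```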
In the streaming scenario, the Shared Encoder module works incrementally. Computing the current chunk requires the inputs and outputs of all the history chunks, so we implement the incremental computation with caches. Overall, three caches are used in our runtime.
- Encoder Conformer/Transformer layers output cache: cache the output of every encoder layer.
- Conformer CNN cache: if Conformer is used, we need to cache the left context for the causal CNN computation in Conformer.
- Subsampling cache: cache the output of subsampling layer, which is the input of the first encoder layer.
Please see encoder.py:forward_chunk() and torch_asr_decoder.cc for details of the caches.
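To make the cache bookkeeping concrete, here is a simplified sketch of one chunk step with the three caches. The per-layer call is a hypothetical placeholder, and cache trimming is omitted (the real code keeps the caches to a required size); the actual implementation is in encoder.py:forward_chunk().

```python
from typing import Callable, List, Tuple

import torch

# Hypothetical layer interface: (full input = history + current chunk,
# cached output for the history frames, cached CNN left context)
# -> (full output, new CNN left context). Only the current chunk's frames
# need to be recomputed inside the layer.
Layer = Callable[[torch.Tensor, torch.Tensor, torch.Tensor],
                 Tuple[torch.Tensor, torch.Tensor]]


def forward_chunk_sketch(
    layers: List[Layer],
    chunk: torch.Tensor,                      # subsampled features of the current chunk (B, T, D)
    subsampling_cache: torch.Tensor,          # cached subsampling output of history chunks
    layer_output_caches: List[torch.Tensor],  # cached output of every encoder layer
    cnn_caches: List[torch.Tensor],           # cached left context of each causal CNN
):
    # Subsampling cache: the input to the first layer is history + current chunk.
    xs = torch.cat([subsampling_cache, chunk], dim=1)
    new_subsampling_cache = xs  # the real code trims this to a required cache size

    new_layer_caches: List[torch.Tensor] = []
    new_cnn_caches: List[torch.Tensor] = []
    for layer, out_cache, cnn_cache in zip(layers, layer_output_caches, cnn_caches):
        # Layer output cache: the history rows of xs were computed in previous
        # chunks, so the layer reuses out_cache for them and only computes the
        # frames of the current chunk, updating the CNN left-context cache too.
        xs, new_cnn = layer(xs, out_cache, cnn_cache)
        new_layer_caches.append(xs)
        new_cnn_caches.append(new_cnn)

    # Only the frames of the current chunk are newly produced output.
    ys = xs[:, -chunk.size(1):, :]
    return ys, new_subsampling_cache, new_layer_caches, new_cnn_caches
```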
In practice, CNN is also used in the subsampling, so in principle we should handle a CNN cache for subsampling as well. However, there are several CNN layers in subsampling with different left contexts, right contexts and strides, which makes it tricky to implement the CNN cache directly in subsampling. In our implementation, we simply overlap the input to avoid a subsampling CNN cache. This is simple and straightforward, with negligible additional cost, since the subsampling CNN accounts for only a very small fraction of the whole computation. The following picture shows how it works, where the blue color marks the overlap between the current input and the previous input.
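As a rough illustration of the idea, the sketch below chunks the feature stream with an overlap so the subsampling CNNs never need an internal cache. `overlapping_chunks` is a hypothetical helper, and `overlap` is a placeholder constant that would in practice be derived from the kernel sizes and strides of the subsampling layers.

```python
from typing import Iterator

import torch


def overlapping_chunks(feats: torch.Tensor, chunk_size: int,
                       overlap: int) -> Iterator[torch.Tensor]:
    """Yield chunks of `chunk_size` feature frames, each prefixed with the last
    `overlap` frames of the previous chunk (the blue region in the picture)."""
    num_frames = feats.size(0)
    start = 0
    while start < num_frames:
        begin = max(0, start - overlap)  # re-feed a small overlap of old frames
        yield feats[begin:start + chunk_size]
        start += chunk_size
```

Each chunk is then passed through the subsampling CNNs from scratch; the duplicated outputs produced by the overlapped frames are simply dropped, which is cheap because subsampling is a small fraction of the total computation.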