[Streaming-ASR] RNN Transducer

Chris blog · September 15, 2023

1. Pros./Cons. of RNN-T

Pros

  • Better accuracy: Removes the conditional independence assumption that CTC relies on
  • Low latency: Can be used in streaming ASR applications
  • RNN-T outperforms MoChA in terms of latency, inference time, and training stability. (Comparison study from Kim et al.)
  • The industry tends to choose RNN-T as the dominant streaming E2E model.

Cons

  • The output prediction tensor is 3D and takes a lot of memory (more detail in Moriya et al.)
  • Vanilla RNN-T can delay its label predictions, which matters because latency is critical for ASR

2. RNN-T formulation

$P(y_t | x_{1:t}, y_{1:u-1})$

Predicting the current token $y_t$ based on:

  • Previous output tokens $y_{1:u-1}$
  • Speech sequence $x_{1:t}$

3. RNN-T Structure

  • Encoder: Generate a high-level feature representation $h_t^{enc}$ from $x_t$

  • Prediction network: Generate $h_u^{pre}$ based on RNN-T's previous output label $y_{u-1}$

  • Joint network: A feed-forward network that combines $h_u^{pre}$ and $h_t^{enc}$ as:

    $$
    z_{t,u} = \psi(Q h_t^{enc} + V h_u^{pre} + b_z) \\
    h_{t,u} = W_y z_{t,u} + b_y \\
    P(y_t = k | x_{1:t}, y_{1:u-1}) = softmax(h_{t,u}^k)
    $$

    Parameters:

    • $Q$ and $V$ are weight matrices.
    • $\psi$ is a non-linear function (e.g., ReLU or tanh)
    • $z_{t,u}$ is again multiplied by another weight matrix $W_y$
    • $b_z$ and $b_y$ are bias vectors
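
The joint network equations above can be sketched in plain Python. This is a minimal illustration rather than a reference implementation: the dimensions (`D_ENC`, `D_PRE`, `D_JOINT`, `K`), the random weights, and the helper names are all assumptions made for the example, and $\psi$ is taken to be tanh.

```python
import math
import random

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(h):
    """Numerically stable softmax over a list of logits."""
    m = max(h)
    exps = [math.exp(x - m) for x in h]
    total = sum(exps)
    return [e / total for e in exps]

def joint(h_enc, h_pre, Q, V, b_z, W_y, b_y):
    """Return the token distribution for one (t, u) pair."""
    # z_{t,u} = psi(Q h_t^enc + V h_u^pre + b_z), with psi = tanh (assumption)
    z = [math.tanh(q + v + b)
         for q, v, b in zip(matvec(Q, h_enc), matvec(V, h_pre), b_z)]
    # h_{t,u} = W_y z_{t,u} + b_y
    h = [w + b for w, b in zip(matvec(W_y, z), b_y)]
    return softmax(h)

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
D_ENC, D_PRE, D_JOINT, K = 4, 3, 5, 6    # hypothetical sizes
Q = rand_matrix(D_JOINT, D_ENC, rng)
V = rand_matrix(D_JOINT, D_PRE, rng)
W_y = rand_matrix(K, D_JOINT, rng)
b_z = [0.0] * D_JOINT
b_y = [0.0] * K

T, U = 3, 2
h_enc_seq = rand_matrix(T, D_ENC, rng)   # stand-in encoder outputs, one per frame t
h_pre_seq = rand_matrix(U, D_PRE, rng)   # stand-in prediction outputs, one per position u

# Evaluating the joint network for every (t, u) pair gives a T x U x K result.
probs = [[joint(h_enc_seq[t], h_pre_seq[u], Q, V, b_z, W_y, b_y)
          for u in range(U)] for t in range(T)]
print(len(probs), len(probs[0]), len(probs[0][0]))
```

Because the joint network is evaluated for every combination of frame and label position, the result already has three axes, even for this toy input.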

3.1 Shape of output

$softmax(h_{t,u}^k) \in \mathbb{R}^{T \times U \times K}$
  • $T$ is the length of the speech sequence
  • $U$ is the length of the label sequence
  • $K$ is the number of possible tokens including special symbols.
    (e.g., start-of-sentence $\langle sos \rangle$, end-of-sentence $\langle eos \rangle$, and the blank symbol)
  • Thus, the output is a 3D tensor, which requires much more memory than the outputs of other E2E models such as CTC and AED.
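
To get a feel for the difference, here is a rough back-of-the-envelope calculation. The batch size, $T$, $U$, $K$, and the float32 storage are hypothetical values chosen only for illustration. It also assumes that CTC produces a $T \times K$ output and AED a $U \times K$ output per utterance.

```python
BYTES_PER_FLOAT32 = 4

def rnnt_output_bytes(batch, T, U, K):
    """RNN-T: one K-way distribution per (t, u) pair."""
    return batch * T * U * K * BYTES_PER_FLOAT32

def ctc_output_bytes(batch, T, K):
    """CTC: one K-way distribution per frame."""
    return batch * T * K * BYTES_PER_FLOAT32

def aed_output_bytes(batch, U, K):
    """AED: one K-way distribution per output token."""
    return batch * U * K * BYTES_PER_FLOAT32

# Hypothetical sizes, not taken from any particular system.
batch, T, U, K = 32, 500, 50, 4000

rnnt = rnnt_output_bytes(batch, T, U, K)
ctc = ctc_output_bytes(batch, T, K)
aed = aed_output_bytes(batch, U, K)

print(f"RNN-T: {rnnt / 1e9:.2f} GB")
print(f"CTC:   {ctc / 1e9:.4f} GB")
print(f"AED:   {aed / 1e9:.4f} GB")
print(f"RNN-T / CTC ratio: {rnnt // ctc}")
```

Under these assumptions the RNN-T output is $U$ times larger than the CTC output and $T$ times larger than the AED output, before counting the gradients kept for backpropagation.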

4. Learnable parameters

  • Prediction network parameters
  • Encoder network parameters
  • $Q$, $V$, $b_z$, $b_y$, $W_y$ from the joint network

5. Alignment Paths

  • The figure shows three possible alignment paths from the bottom left corner to the top right corner of the $T \times U$ grid.
  • The length of each alignment path: $T + U$.
  • Horizontal arrow: Advance one time step with a blank label.
  • Vertical arrow: Emit a non-blank output label, moving one step up the label axis while staying on the same time step.


    x-axis: Speech sequence $x = (x_1, x_2, ..., x_8)$
    y-axis: Label sequence $y = (\langle s \rangle, t, e, a, m)$, where $\langle s \rangle$ is the start-of-sentence token.
    Delayed decision/prediction: The green path in the image above. Latency is high because the labels are predicted late, which is a problem of vanilla RNN-T.
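
The grid walk described in this section can be made concrete with a small enumeration. This is a sketch under two assumptions: the blank is written as a hypothetical `"<b>"` symbol, and every path ends with a blank that consumes the last frame (the usual transducer lattice convention). A blank moves one frame to the right, and a label moves one step up without consuming a frame.

```python
from math import comb

BLANK = "<b>"  # hypothetical symbol for the blank label

def enumerate_alignments(T, labels):
    """List every alignment path for T frames and the given label sequence."""
    U = len(labels)
    paths = []

    def walk(t, u, prefix):
        if t == T:
            # All frames are consumed; keep the path only if every label was emitted.
            if u == U:
                paths.append(prefix)
            return
        # Horizontal move: emit a blank and advance to the next frame.
        walk(t + 1, u, prefix + [BLANK])
        # Vertical move: emit the next label and stay on the same frame.
        if u < U:
            walk(t, u + 1, prefix + [labels[u]])

    walk(0, 0, [])
    return paths

def collapse(path):
    """Map an alignment path back to its label sequence by dropping blanks."""
    return [s for s in path if s != BLANK]

labels = list("team")
T = 3  # a short hypothetical utterance keeps the enumeration small
paths = enumerate_alignments(T, labels)

print(len(paths), "paths, each of length", len(paths[0]))
print(paths[0])    # blanks first: every label is emitted on the last frame
print(paths[-1])   # labels first: every label is emitted on the first frame
```

Every enumerated path has exactly $T$ blanks and $U$ labels, so its length is $T + U$, and dropping the blanks always recovers the same label sequence. The number of paths grows combinatorially with $T$ and $U$, which is why the loss is computed with dynamic programming rather than by enumeration.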

6. RNN-T Loss

  • RNN-T tries to minimize $-\ln P(y|x)$, where

    $P(y|x) = \sum_{a \in A^{-1}(y)} P(a|x)$

    $a$: One of the possible alignment paths
    $A$: The mapping from the alignment path $a$ to the label sequence $y$, i.e. $A(a) = y$.

  • The parameters are optimized using the forward-backward algorithm (Graves).
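
As an illustration of how the sum over alignment paths can be computed without enumerating them, here is a minimal sketch of the forward recursion, checked against a brute-force sum. Several things are assumptions made for this example: the tables `blank[t][u]` and `label[t][u]` are random stand-ins for the joint network output (the probability of a blank, and of the next correct label, at grid node $(t, u)$), and the recursion is done in the probability domain for readability. Practical implementations typically work in the log domain.

```python
import math
import random

def forward_likelihood(blank, label):
    """Compute P(y|x) with the forward variable alpha over the (t, u) grid."""
    T, U1 = len(blank), len(blank[0])     # U1 = U + 1 label positions
    alpha = [[0.0] * U1 for _ in range(T)]
    alpha[0][0] = 1.0
    for t in range(T):
        for u in range(U1):
            if t > 0:   # arrive from the left by emitting a blank
                alpha[t][u] += alpha[t - 1][u] * blank[t - 1][u]
            if u > 0:   # arrive from below by emitting a label
                alpha[t][u] += alpha[t][u - 1] * label[t][u - 1]
    # A final blank consumes the last frame.
    return alpha[T - 1][U1 - 1] * blank[T - 1][U1 - 1]

def brute_force_likelihood(blank, label):
    """Sum P(a|x) over every alignment path explicitly (small grids only)."""
    T, U1 = len(blank), len(blank[0])

    def walk(t, u):
        if t == T:
            return 1.0 if u == U1 - 1 else 0.0
        total = blank[t][u] * walk(t + 1, u)
        if u < U1 - 1:
            total += label[t][u] * walk(t, u + 1)
        return total

    return walk(0, 0)

# Random stand-in probabilities; blank + label < 1 at every node because
# the remaining mass would belong to the other tokens.
rng = random.Random(0)
T, U = 4, 3
blank = [[rng.uniform(0.1, 0.9) for _ in range(U + 1)] for _ in range(T)]
label = [[(1.0 - blank[t][u]) * rng.uniform(0.1, 0.9) for u in range(U + 1)]
         for t in range(T)]

p = forward_likelihood(blank, label)
loss = -math.log(p)
print(p, brute_force_likelihood(blank, label), loss)
```

The two likelihood values agree, but the recursion touches each grid node once instead of walking every path.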

7. Forward-backward Algorithm

7.1 Implementation

(WIP)

7.2 How to improve training efficiency

  • Loop skewing transformation: The forward/backward probabilities can be vectorized, so the recursions can be computed in a single loop instead of two nested loops.
  • Function merging: Reduces the training memory cost so that larger minibatches can be used.
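
The idea behind loop skewing can be sketched as follows. All cells on the same anti-diagonal ($t + u = n$) depend only on the previous anti-diagonal, so the two nested loops can be replaced by a single loop over $n$. This pure-Python sketch is an illustration, not the implementation from any particular paper or library. The per-diagonal list comprehension stands in for what would be one vectorized operation in an array library, and the function names and random inputs are assumptions.

```python
import random

def alpha_nested(blank, label):
    """Reference forward variables using two nested loops over t and u."""
    T, U1 = len(blank), len(blank[0])
    alpha = [[0.0] * U1 for _ in range(T)]
    alpha[0][0] = 1.0
    for t in range(T):
        for u in range(U1):
            if t > 0:
                alpha[t][u] += alpha[t - 1][u] * blank[t - 1][u]
            if u > 0:
                alpha[t][u] += alpha[t][u - 1] * label[t][u - 1]
    return alpha

def alpha_skewed(blank, label):
    """Same forward variables, computed one anti-diagonal at a time."""
    T, U1 = len(blank), len(blank[0])
    alpha = [[0.0] * U1 for _ in range(T)]
    alpha[0][0] = 1.0
    for n in range(1, T + U1 - 1):            # single loop over diagonals t + u = n
        us = range(max(0, n - T + 1), min(n, U1 - 1) + 1)
        # Every cell on diagonal n reads only diagonal n - 1, so the cells are
        # independent of each other and could be computed in one vectorized step.
        values = [
            (alpha[n - u - 1][u] * blank[n - u - 1][u] if n - u > 0 else 0.0)
            + (alpha[n - u][u - 1] * label[n - u][u - 1] if u > 0 else 0.0)
            for u in us
        ]
        for u, value in zip(us, values):
            alpha[n - u][u] = value
    return alpha

# Random stand-in probabilities for a small hypothetical grid.
rng = random.Random(1)
T, U = 5, 3
blank = [[rng.uniform(0.1, 0.9) for _ in range(U + 1)] for _ in range(T)]
label = [[(1.0 - blank[t][u]) * rng.uniform(0.1, 0.9) for u in range(U + 1)]
         for t in range(T)]

a_ref = alpha_nested(blank, label)
a_skew = alpha_skewed(blank, label)
max_diff = max(abs(a_ref[t][u] - a_skew[t][u])
               for t in range(T) for u in range(U + 1))
print(max_diff)
```

The outer loop runs $T + U$ times instead of $T \times (U + 1)$, which is where the speedup comes from once each diagonal is processed as a vector.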

8. Different Strategies for Alignments

8.1 Constrained alignment

(WIP)

8.2 FastEmit

(WIP)

8.3 Self-alignment

Summary: Self-alignment pushes the model's alignment to the left, toward lower-latency alignments. It was reported to give a better accuracy-latency tradeoff than previous methods.

  • The blue path is the self-alignment path, and the red path is one frame to the left of it.
  • During training, the method encourages the left-alignment path, pushing the model's alignment to the left.