Pruned RNN-T Explained

This blog explains the paper "Pruned RNN-T for fast, memory-efficient ASR training".

Paper theme

RNN-T has a slow and memory-intensive loss function, which limits its use with large vocabularies such as Chinese characters. The main idea of Pruned RNN-T is to reduce the computational and memory requirements by evaluating the joiner network only on a pruned subset of the lattice, giving faster training while maintaining accuracy. The pruning bounds are obtained cheaply from a trivial joiner that is simply linear in the encoder and decoder embeddings.

Source Code

Introduction

Motivations

Illustration of a generic transducer model (Lugosch, 2020).

Tokens:

Equations:
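For reference, the standard transducer loss over the alignment lattice can be summarized by the forward recursion below (following Graves, 2012), where \(y(t,u)\) denotes the probability of emitting the next reference label and \(\phi(t,u)\) the probability of emitting a blank at position \((t,u)\):

\[\alpha(t,u) = \alpha(t-1,u)\,\phi(t-1,u) + \alpha(t,u-1)\,y(t,u-1), \qquad \alpha(1,0) = 1\]

\[L = -\log\bigl(\alpha(T,U)\,\phi(T,U)\bigr)\]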

Illustration of the lattice in standard RNN-T (right) vs. pruned RNN-T (left) (Kuang et al., 2022).
What is a lattice, and why is it important?

The lattice holds, for every pair of time step \(t\) and label index \(u\), the log-probabilities of the possible transitions: emitting the next reference token (an upward move) or emitting a blank (a rightward move to the next frame). It is important because the RNN-T loss sums over all paths through this lattice, and evaluating the joiner at every \((t, u)\) position is exactly what makes the standard loss slow and memory-hungry.
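To make that cost concrete, here is a minimal PyTorch sketch of how the full, unpruned lattice of log-probabilities is formed; the toy shapes and the simple additive joiner are illustrative assumptions, not the paper's configuration.

```python
import torch

# Toy sizes so the sketch runs; real settings are much larger (e.g. several
# hundred frames, ~100 labels, a vocabulary of several thousand tokens), in
# which case the (B, T, U+1, V) tensor alone can occupy several GB of memory.
B, T, U, V = 2, 50, 10, 100
blank = 0

encoder_out = torch.randn(B, T, 1, V)      # acoustic embedding, broadcast over u
decoder_out = torch.randn(B, 1, U + 1, V)  # label embedding, broadcast over t

# A "trivial" joiner is just a sum followed by normalization; a full joiner
# would apply a non-linearity and a linear layer at every (t, u) instead.
log_probs = torch.log_softmax(encoder_out + decoder_out, dim=-1)  # (B, T, U+1, V)

# Lattice transition log-probs:
phi = log_probs[..., blank]  # rightward (blank) transition at each (t, u)
# The upward (label) transition at (t, u) is the log-prob of reference token
# u+1 (gathering those reference labels is omitted in this sketch).
print(log_probs.shape)  # torch.Size([2, 50, 11, 100])
```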

Pruned RNN-T

In the context of the pruned RNN-T model, the quantities \(y'(t, u)\) and \(\phi'(t, u)\) are "occupation counts" within a specific interval, indicating the probability mass carried by the upward and rightward transitions at position \((t, u)\). Let's estimate the retained probability mass for \(S = 4\) (the number of label indices evaluated at each time step \(t\)) and pruning bound \(p_t = 2\):

\[\phi'(t,2) + \phi'(t,3) + \phi'(t,4) + \phi'(t,5) - y'(t,1)\]

The occupation count \(y'(t, 1)\) in the equation above is the probability of the upward transition at label index 1, shown in red in the lattice diagram above. It is subtracted to make the estimate more accurate: paths that reach the retained rows by climbing up from label indices below the pruning bound use transitions that the pruned lattice never evaluates, so their probability mass should not be counted as retained.
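As a quick illustration, here is a small sketch of this estimate; the helper name and the random occupation counts are made up for the example, and only the arithmetic mirrors the equation above.

```python
import torch

def retained_mass(y_occ: torch.Tensor, phi_occ: torch.Tensor, p_t: int, S: int) -> torch.Tensor:
    """Approximate probability mass kept at one frame t if only label
    indices u in [p_t, p_t + S - 1] are evaluated.

    y_occ[u]  : occupation count of the upward (label) transition at (t, u)
    phi_occ[u]: occupation count of the rightward (blank) transition at (t, u)
    """
    kept = phi_occ[p_t : p_t + S].sum()
    # Paths entering the kept rows from below cross the pruned transition
    # y'(t, p_t - 1); subtract it so that mass is not counted as retained.
    if p_t > 0:
        kept = kept - y_occ[p_t - 1]
    return kept

# Example matching the text: S = 4, p_t = 2 (random counts, purely illustrative;
# the rightward counts at one frame sum to 1, since each path leaves frame t once).
phi_occ = torch.rand(10)
phi_occ /= phi_occ.sum()
y_occ = torch.rand(10) * 0.1
print(retained_mass(y_occ, phi_occ, p_t=2, S=4))
```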

The loss of the trivial joiner is smoothed by mixing in an "LM-only" term (decoder side alone) and an "acoustic-only" term (encoder side alone):

\[L_{smoothed} = (1 - \alpha^{lm} - \alpha^{acoustic})\,L_{trivial} + \alpha^{lm} L_{lm} + \alpha^{acoustic} L_{acoustic}\]

Refer to Section 3.3 of the paper for a more detailed explanation.
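In code this is just a linear combination of three scalar losses; the placeholder values and scales below are illustrative only, not the paper's settings.

```python
import torch

# Placeholder loss values and scales, purely illustrative.
loss_trivial = torch.tensor(12.3)   # loss of the trivial (linear) joiner
loss_lm = torch.tensor(15.0)        # "LM-only" loss: decoder/label side alone
loss_acoustic = torch.tensor(14.1)  # "acoustic-only" loss: encoder side alone
alpha_lm, alpha_acoustic = 0.25, 0.0

loss_smoothed = (
    (1.0 - alpha_lm - alpha_acoustic) * loss_trivial
    + alpha_lm * loss_lm
    + alpha_acoustic * loss_acoustic
)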

Experimental Settings

| Category | Parameter / Hyperparameter | Value |
|---|---|---|
| Dataset | Corpus | LibriSpeech |
| | Training hours | 960 hours |
| | Test sets | test-clean, test-other |
| | Test set speech duration | ~5 hours each |
| Input features | Feature type | 80-dimensional log Mel filter bank |
| | Window size | 25 ms |
| | Window shift | 10 ms |
| Data augmentation | Speed perturbation factors | 0.9 and 1.1 |
| Model architecture | Encoder | Conformer, 12 layers |
| | Encoder self-attention heads | 8 |
| | Encoder attention dim | 512 |
| | Encoder feed-forward dim | 2048 |
| | Decoder | Stateless: embedding layer + 1-D convolutional layer |
| | Decoder embedding dim | 512 |
| Hardware | GPUs | 8 × NVIDIA V100 (32 GB) |
| Training strategy | Pruning | Enable pruned loss after the trivial loss has converged |
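
Putting the pieces together, here is a sketch of how this two-stage loss is typically wired up with the k2 library. The function names follow k2's Python API as used in the icefall recipes, but the exact signatures, defaults, and loss scales are assumptions and may differ between versions; treat this as an outline rather than the paper's exact code.

```python
import torch
import k2

# Toy shapes (illustrative only): batch, frames, label positions, vocab size.
B, T, U, V = 2, 50, 10, 500
blank_id = 0
prune_range = 5  # S: label indices evaluated per frame

# Trivial-joiner inputs: encoder and decoder outputs already projected to V dims.
am = torch.randn(B, T, V, requires_grad=True)       # acoustic / encoder side
lm = torch.randn(B, U + 1, V, requires_grad=True)   # label / decoder side
symbols = torch.randint(1, V, (B, U), dtype=torch.int64)
boundary = torch.tensor([[0, 0, U, T]] * B, dtype=torch.int64)

# Stage 1: trivial joiner with LM/acoustic smoothing. return_grad=True also
# returns the gradients of the total log-prob w.r.t. the lattice terms,
# which are exactly the occupation counts y'(t, u) and phi'(t, u).
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm, am=am, symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=0.25, am_only_scale=0.0,
    boundary=boundary, return_grad=True,
)

# Stage 2: choose per-frame pruning bounds from the occupation counts, keep only
# an S-wide band of label indices per frame, and evaluate the joiner there.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# A real model would apply its full (non-linear) joiner here; adding the two
# projected tensors stands in for it in this sketch.
logits = am_pruned + lm_pruned  # (B, T, S, V) instead of (B, T, U+1, V)

pruned_loss = k2.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=blank_id, boundary=boundary,
)

# Per the training strategy above, the pruned term is enabled (or scaled up)
# only after the trivial loss has converged; the scales here are placeholders.
loss = 0.5 * simple_loss + pruned_loss
```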

Results