CTC Loss and the Dynamic Programming Approach
This post is a minimal walkthrough of the CTC loss, written mainly for my own understanding.
The goal is to derive a loss function suitable for training an RNN with back-propagation, so it has to be differentiable with respect to the network output at each time step.
Basic Idea
Interpret the network outputs as a probability distribution over all possible sequences of labels.
Introduce blanks $\epsilon$ in the output: for the true label sequence $\mathbf{l}$, we insert $\epsilon$ at the start and the end, as well as between adjacent characters, to create a modified label $\mathbf{l'}$ of length $2|\mathbf{l}|+1$. For the word $hello$ this is \(\epsilon-h-\epsilon-e-\epsilon-l-\epsilon-l-\epsilon-o-\epsilon\)
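As a small illustration of the blank insertion, here is a sketch in Python. The function name `extend_with_blanks` and the use of `"_"` as a stand-in for $\epsilon$ are my own choices for this example, not part of any library.

```python
BLANK = "_"  # stands in for the blank symbol; an arbitrary choice for this sketch


def extend_with_blanks(label):
    """Return the modified label: a blank before, between and after every character."""
    extended = [BLANK]
    for ch in label:
        extended += [ch, BLANK]
    return extended


# Prints the modified label for "hello", with "-" between symbols as in the text.
print("-".join(extend_with_blanks("hello")))
```

The result always has $2|\mathbf{l}|+1$ symbols, with blanks at the odd positions when counting from 1.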
We map a network output sequence to the true label by
- removing blanks between two distinct labels e.g. $h-\epsilon-e$ becomes $he$
- collapsing repeated labels that have no blanks in between, e.g. $h-h-e$ becomes $he$
- keeping repeated labels with blanks in between e.g. $l-\epsilon-l$ becomes $ll$
- removing blanks at the very start and end
A network output sequence $\pi$ that maps to the true label sequence in this way is called valid.
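The mapping rules above amount to "merge consecutive repeats, then drop blanks". A minimal sketch, assuming `"_"` stands in for $\epsilon$ and with `collapse` and `is_valid` as hypothetical helper names:

```python
from itertools import groupby

BLANK = "_"  # stands in for the blank symbol; an arbitrary choice for this sketch


def collapse(path):
    """Map a per-time-step path to a label sequence."""
    # Step 1: merge runs of identical consecutive symbols (h-h-e -> h-e).
    merged = [symbol for symbol, _ in groupby(path)]
    # Step 2: drop every blank, including those at the start and end.
    return "".join(symbol for symbol in merged if symbol != BLANK)


def is_valid(path, label):
    """A path is valid for `label` if it collapses to exactly that label."""
    return collapse(path) == label


print(collapse("_hee_l_lo_"))  # collapses to "hello"
print(is_valid("helo", "hello"))  # False: the two l's need a blank between them
```

Doing the merge before the blank removal is what keeps $l-\epsilon-l$ as $ll$ while $l-l$ becomes a single $l$.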
- So we can compute (assuming conditional independence between network outputs)
\(p(\pi | \mathbf{x}) \ = \ \prod_{t=1}^T y_{\pi_t}^t\)
- Here $\pi$ is a valid sequence, $\mathbf{x}$ is an input to the network, $\pi_t$ is the entry of $\pi$ at time $t$, and $y_{\pi_t}^t$ is the probability of observing $\pi_t$ at time $t$ (the network output).
- So the total probability of obtaining the true sequence $\mathbf{l}$ is \(p(\mathbf{l} | \mathbf{x}) \ = \ \sum_{valid \ \pi} p(\pi | \mathbf{x})\)
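To make the two formulas concrete, the sum can be evaluated by brute force on a tiny example: enumerate every possible path, keep the valid ones, and add up their products. This is only a sketch for checking intuition (the number of paths grows exponentially with $T$); the numbers in `y` are made up, and `"_"` stands in for $\epsilon$.

```python
from itertools import groupby, product

BLANK = "_"  # stands in for the blank symbol


def collapse(path):
    """Merge consecutive repeats, then drop blanks."""
    merged = [symbol for symbol, _ in groupby(path)]
    return "".join(symbol for symbol in merged if symbol != BLANK)


def path_probability(path, y):
    """p(pi | x): product over time of the probability of the emitted symbol."""
    prob = 1.0
    for t, symbol in enumerate(path):
        prob *= y[t][symbol]
    return prob


def brute_force_probability(label, y):
    """p(l | x): sum of path probabilities over all paths that collapse to `label`."""
    alphabet = list(y[0])
    total = 0.0
    for path in product(alphabet, repeat=len(y)):
        if collapse(path) == label:
            total += path_probability(path, y)
    return total


# Made-up network outputs for T = 3 over the alphabet {_, a, b}; each row sums to 1.
y = [
    {"_": 0.2, "a": 0.7, "b": 0.1},
    {"_": 0.5, "a": 0.3, "b": 0.2},
    {"_": 0.1, "a": 0.2, "b": 0.7},
]
print(brute_force_probability("ab", y))
```

For the label `ab` with $T=3$ the valid paths are `aab`, `abb`, `a_b`, `_ab` and `ab_`, so the result is the sum of five products.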
The objective function is the negative log-likelihood $-\ln p(\mathbf{l} | \mathbf{x})$, which we wish to minimise (equivalently, we maximise the likelihood).
We need the derivative of this error with respect to each network output $y_k^t$, i.e. the probability of emitting label $k$ at time $t$.
So we need a way to express the negative log-likelihood in terms of those probabilities.
Dynamic Programming Solution
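- We have two goals:
  - Computing the total probability $p(\mathbf{l} | \mathbf{x})$.
  - Computing the total probability in terms of $y_k^t$.
- For the first one, we use the so-called Forward Algorithm, which defines $\alpha_t(s)$ as the total probability of all length-$t$ path prefixes that end at the $s^{th}$ symbol of the modified label $\mathbf{l'}$ and cover $\mathbf{l'}$ up to that symbol. We allow a path to start at either the leading $\epsilon$ or the first ground truth label, so $\alpha_1(1) = y_{\epsilon}^1$, $\alpha_1(2) = y_{\mathbf{l_2'}}^1$ and $\alpha_1(s) = 0$ for $s > 2$. For $t > 1$ (terms with an index below 1 are taken to be zero):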
\(\alpha_t(s)=\begin{cases} (\alpha_{t-1}(s)+\alpha_{t-1}(s-1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \epsilon \\ (\alpha_{t-1}(s)+\alpha_{t-1}(s-1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \mathbf{l_{s-2}'} \\ (\alpha_{t-1}(s)+\alpha_{t-1}(s-1)+\alpha_{t-1}(s-2))y_{\mathbf{l_s'}}^t & otherwise \\ \end{cases}\)
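In the first two cases, $\alpha_{t-1}(s)$ accounts for staying on the same symbol (a repeated output, e.g. $\epsilon \epsilon$), while $\alpha_{t-1}(s-1)$ accounts for reaching $\mathbf{l_s'}$ for the first time from the previous symbol. A jump from $s-2$ is not allowed here: if $\mathbf{l_s'}$ is a blank, the jump would skip a ground truth label, and if $\mathbf{l_s'} = \mathbf{l_{s-2}'}$, skipping the blank between them would merge the two repeated labels into one.
For the third case, the extra $\alpha_{t-1}(s-2)$ is added because $\mathbf{l_{s-1}'}$ is a blank between two distinct labels, so it can safely be skipped.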
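Therefore, the desired probability is the sum at time $T$ of the paths that end on the final blank and those that end on the last ground truth label: \(p(\mathbf{l} | \mathbf{x}) \ = \ \alpha_T(|\mathbf{l'}|) + \alpha_T(|\mathbf{l'}|-1)\)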
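The forward recursion can be sketched directly in Python. This is a plain, unscaled version for small examples, assuming `y[t]` is a dict from symbol to probability and `"_"` stands in for $\epsilon$; indices are 0-based, so `alpha[t][s]` corresponds to $\alpha_{t+1}(s+1)$ in the text. For long sequences these products underflow, so a practical version would rescale or work in log space.

```python
import math

BLANK = "_"  # stands in for the blank symbol


def extend_with_blanks(label):
    """Blank before, between and after every character of the label."""
    extended = [BLANK]
    for ch in label:
        extended += [ch, BLANK]
    return extended


def forward(label, y):
    """Return the T x |l'| table of forward variables (0-indexed)."""
    ext = extend_with_blanks(label)
    T, S = len(y), len(ext)
    alpha = [[0.0] * S for _ in range(T)]
    # Initialisation: a path may start on the leading blank or the first label.
    alpha[0][0] = y[0][ext[0]]
    if S > 1:
        alpha[0][1] = y[0][ext[1]]
    for t in range(1, T):
        for s in range(S):
            total = alpha[t - 1][s]  # stay on the same symbol
            if s >= 1:
                total += alpha[t - 1][s - 1]  # advance by one symbol
            # Skip position s-1 only if it is a blank between two distinct labels.
            if s >= 2 and ext[s] != BLANK and ext[s] != ext[s - 2]:
                total += alpha[t - 1][s - 2]
            alpha[t][s] = total * y[t][ext[s]]
    return alpha


def ctc_probability(label, y):
    """p(l | x): paths ending on the final blank plus paths ending on the last label."""
    last = forward(label, y)[-1]
    return last[-1] + (last[-2] if len(last) > 1 else 0.0)


def ctc_loss(label, y):
    """Negative log-likelihood of the label."""
    return -math.log(ctc_probability(label, y))


# Made-up network outputs for T = 3 over the alphabet {_, a, b}.
y = [
    {"_": 0.2, "a": 0.7, "b": 0.1},
    {"_": 0.5, "a": 0.3, "b": 0.2},
    {"_": 0.1, "a": 0.2, "b": 0.7},
]
print(ctc_probability("ab", y), ctc_loss("ab", y))
```

The table has $T \times |\mathbf{l'}|$ entries and each is filled in constant time, which is what makes this tractable compared with enumerating every path.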
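Now the Backward Algorithm defines variables $\beta_t(s)$ similarly; $\beta_t(s)$ gives the total probability of all path suffixes that start at the $s^{th}$ symbol of the modified $\mathbf{l'}$ at time $t$ and complete the rest of $\mathbf{l'}$ by time $T$. A path may end on either the final $\epsilon$ or the last ground truth label, so $\beta_T(|\mathbf{l'}|) = y_{\epsilon}^T$, $\beta_T(|\mathbf{l'}|-1) = y_{\mathbf{l_{|l'|-1}'}}^T$ and $\beta_T(s) = 0$ otherwise. For $t < T$ (terms with an index above $|\mathbf{l'}|$ are taken to be zero):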
\(\beta_t(s)=\begin{cases} (\beta_{t+1}(s)+\beta_{t+1}(s+1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \epsilon \\ (\beta_{t+1}(s)+\beta_{t+1}(s+1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \mathbf{l_{s+2}'} \\ (\beta_{t+1}(s)+\beta_{t+1}(s+1)+\beta_{t+1}(s+2))y_{\mathbf{l_s'}}^t & otherwise \\ \end{cases}\)
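The backward recursion mirrors the forward one, running from $t = T$ down to $t = 1$ and looking ahead to $s+1$ and $s+2$. A minimal unscaled sketch, again with 0-based indices, `y[t]` as a dict from symbol to probability, and `"_"` standing in for $\epsilon$:

```python
BLANK = "_"  # stands in for the blank symbol


def extend_with_blanks(label):
    """Blank before, between and after every character of the label."""
    extended = [BLANK]
    for ch in label:
        extended += [ch, BLANK]
    return extended


def backward(label, y):
    """Return the T x |l'| table of backward variables (0-indexed)."""
    ext = extend_with_blanks(label)
    T, S = len(y), len(ext)
    beta = [[0.0] * S for _ in range(T)]
    # Initialisation: a path may end on the final blank or the last label.
    beta[T - 1][S - 1] = y[T - 1][ext[S - 1]]
    if S > 1:
        beta[T - 1][S - 2] = y[T - 1][ext[S - 2]]
    for t in range(T - 2, -1, -1):
        for s in range(S):
            total = beta[t + 1][s]  # stay on the same symbol
            if s + 1 < S:
                total += beta[t + 1][s + 1]  # advance by one symbol
            # Skip position s+1 only if it is a blank between two distinct labels.
            if s + 2 < S and ext[s] != BLANK and ext[s] != ext[s + 2]:
                total += beta[t + 1][s + 2]
            beta[t][s] = total * y[t][ext[s]]
    return beta


def probability_from_beta(label, y):
    """p(l | x) read off the first time step: start on the blank or the first label."""
    first = backward(label, y)[0]
    return first[0] + (first[1] if len(first) > 1 else 0.0)


# Made-up network outputs for T = 3 over the alphabet {_, a, b}.
y = [
    {"_": 0.2, "a": 0.7, "b": 0.1},
    {"_": 0.5, "a": 0.3, "b": 0.2},
    {"_": 0.1, "a": 0.2, "b": 0.7},
]
print(probability_from_beta("ab", y))
```

Because $\beta_t(s)$ as defined here already includes the factor $y_{\mathbf{l_s'}}^t$, the total probability can be read directly from the first time step, and it should agree with the value obtained from the forward pass.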
Now we can compute the derivative of the negative log-likelihood
\(-\frac{\partial \ln(p(\mathbf{l}|\mathbf{x}))}{\partial y_k^t}\)
- Note from both algorithms and the law of total probability that, for any time step $t$ (the factor $y_{\mathbf{l_s'}}^t$ appears in both $\alpha_t(s)$ and $\beta_t(s)$, so we divide by it once) \(p(\mathbf{l} | \mathbf{x}) \ = \ \sum_{s=1}^{|\mathbf{l'}|} \ \frac{\alpha_t(s) \ \beta_t(s)}{y_{\mathbf{l_s'}}^t}\)
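- Now the computation becomes easy: when differentiating with respect to $y_k^t$, only the terms with $\mathbf{l_s'} = k$ depend on it.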
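To spell out that last step, here is a sketch of the differentiation following Graves et al. (2006). The set $lab(\mathbf{l}, k) = \{s : \mathbf{l_s'} = k\}$ is notation borrowed from that paper, not defined earlier in this post. Each term $\alpha_t(s)\beta_t(s)/y_k^t$ with $s \in lab(\mathbf{l}, k)$ is a sum over paths that pass through $s$ at time $t$, so it is linear in $y_k^t$ and its derivative is the same term divided by $y_k^t$ once more:

\(\frac{\partial p(\mathbf{l} | \mathbf{x})}{\partial y_k^t} \ = \ \frac{1}{(y_k^t)^2} \sum_{s \in lab(\mathbf{l}, k)} \alpha_t(s) \ \beta_t(s)\)

Applying the chain rule to the logarithm then gives

\(-\frac{\partial \ln(p(\mathbf{l}|\mathbf{x}))}{\partial y_k^t} \ = \ -\frac{1}{p(\mathbf{l} | \mathbf{x})} \ \frac{1}{(y_k^t)^2} \sum_{s \in lab(\mathbf{l}, k)} \alpha_t(s) \ \beta_t(s)\)

If label $k$ does not occur in $\mathbf{l'}$, the sum is empty and the derivative is zero.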
While traditional loss functions like MSE assume a one-to-one correspondence between network outputs and target labels, this is often not the case in tasks like speech recognition.
The CTC loss provides a way to quantify the error of the output by interpreting the outputs as a probability distribution over label sequences given the input. It does not require alignment or segmentation of the inputs.
[1] https://www.youtube.com/watch?v=UMxvZ9qHwJs&t=542s
[2] Graves et al. (2006). Connectionist temporal classification: Labelling unsegmented sequence data with recurrent neural networks.