CTC Loss and the Dynamic Programming Approach

This post serves as a minimal approach to understanding the CTC loss (written mainly for myself).

Goal

To derive a suitable loss function for training an RNN with back-propagation (so it has to be differentiable wrt. the network output at each time step).

Basic Idea

  • Interpret the network outputs as a probability distribution over all possible label sequences

  • Introduce blanks $\epsilon$ into the output; e.g. for the true label sequence $\mathbf{l}$, we insert $\epsilon$ at the start and the end, as well as between adjacent characters, to create a modified true label $\mathbf{l'}$; for the word $hello$ this is \(\epsilon-h-\epsilon-e-\epsilon-l-\epsilon-l-\epsilon-o-\epsilon\)

  • We map a network output sequence to the true label sequence by

    • removing blanks between two distinct labels, e.g. $h-\epsilon-e$ becomes $he$
    • collapsing repeated labels that have no blanks in between, e.g. $h-h-e$ becomes $he$
    • keeping repeated labels with blanks in between, e.g. $l-\epsilon-l$ becomes $ll$
    • removing blanks at the very start and end
  • The output sequences that map to the true label sequence are called valid (a short sketch of this collapse map follows this list)

  • So we can compute (assuming conditional independence between network outputs across time steps)
    \(p(\pi | \mathbf{x}) \ = \ \prod_{t=1}^T y_{\pi_t}^t\)
  • Here $\pi$ is an output sequence (a path), $\mathbf{x}$ is the input to the network, $\pi_t$ is the path's entry at time $t$, and $y_{\pi_t}^t$ is the probability the network assigns to $\pi_t$ at time $t$ (the network output).
  • So the total probability of obtaining the true sequence $\mathbf{l}$ is \(p(\mathbf{l} | \mathbf{x}) \ = \ \sum_{\text{valid} \ \pi} p(\pi | \mathbf{x})\)
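
As a quick illustration, here is a minimal Python sketch of the collapse map described above (the `collapse` function name and the blank sentinel are my own choices, not from the paper):

```python
# Minimal sketch of the collapse map: merge repeats, then drop blanks.
BLANK = "ε"  # stand-in for the blank symbol epsilon

def collapse(path):
    """Map a per-time-step output sequence to a label sequence."""
    out = []
    prev = None
    for symbol in path:
        # Keep a symbol only if it is not a blank and not a repeat of the
        # immediately preceding symbol (repeats separated by blanks survive).
        if symbol != BLANK and symbol != prev:
            out.append(symbol)
        prev = symbol
    return "".join(out)

print(collapse("hεeεllεlo"))   # -> "hello": l-ε-l is kept as "ll"
print(collapse("hheelllloo"))  # -> "helo": repeats with no blank in between collapse
```

The valid sequences summed over above are exactly the length-$T$ paths that this map sends to $\mathbf{l}$.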

Training

  • The objective function is the negative log-likelihood, which we wish to minimise.

  • We need to compute the derivative of the error wrt. the output at each time step $t$, $y_k^t$, i.e. the probability the network assigns to label $k$ at time $t$.

  • So we need a way to express the negative log-likelihood in terms of those probabilities.
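
  • As a small worked step: by the chain rule,
    \(-\frac{\partial \ \ln p(\mathbf{l}|\mathbf{x})}{\partial y_k^t} \ = \ -\frac{1}{p(\mathbf{l}|\mathbf{x})} \ \frac{\partial \ p(\mathbf{l}|\mathbf{x})}{\partial y_k^t}\)
    so it is enough to express $p(\mathbf{l} | \mathbf{x})$ itself in terms of the $y_k^t$.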

Dynamic Programming Solution

  • We have two goals:

    • Computing the total probability $p(\mathbf{l} | \mathbf{x})$.
    • Expressing this total probability in terms of the network outputs $y_k^t$.

  • For the first goal, we use the so-called Forward Algorithm, which defines $\alpha_t(s)$ as the probability of obtaining a valid sequence by time $t$ that covers the modified $\mathbf{l'}$ up to the $s^{th}$ character. We allow paths to start at either the leading $\epsilon$ or the first ground truth label.

\(\alpha_t(s)=\begin{cases} (\alpha_{t-1}(s)+\alpha_{t-1}(s-1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \epsilon \\ (\alpha_{t-1}(s)+\alpha_{t-1}(s-1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \mathbf{l_{s-2}'} \\ (\alpha_{t-1}(s)+\alpha_{t-1}(s-1)+\alpha_{t-1}(s-2))y_{\mathbf{l_s'}}^t & \text{otherwise} \\ \end{cases}\)

  • The first two cases above are clear: $\alpha_{t-1}(s)$ accounts for staying on the same symbol (e.g. a repeated $\epsilon$), while $\alpha_{t-1}(s-1)$ corresponds to reaching this $\mathbf{l_s'}$ for the first time.

  • For the third case, the extra $\alpha_{t-1}(s-2)$ is added because the $(s-1)^{th}$ character can be skipped: it is a blank, and two distinct adjacent labels do not need a blank between them.

  • Therefore, the desired probability is the sum, at time $T$, of the forward variables ending on the final blank and on the final label (a NumPy sketch of the forward pass is given below the formula).

\[p(\mathbf{l} | \mathbf{x}) \ = \ \alpha_{T}(|\mathbf{l'}|) + \ \alpha_{T}(|\mathbf{l'}|-1)\]
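
As a concrete illustration, here is a minimal NumPy sketch of the forward pass (function and variable names are my own; probabilities are kept unnormalised, whereas a real implementation would rescale $\alpha_t$ or work in log space to avoid underflow):

```python
import numpy as np

def ctc_forward(y, l_prime, blank):
    """Forward variables for CTC.

    y       : array of shape (T, K), y[t, k] = probability of symbol k at time t
    l_prime : modified label sequence (blanks inserted), as a list of symbol indices
    blank   : index of the blank symbol in the output alphabet
    Assumes a non-empty label sequence, so len(l_prime) >= 3.
    """
    T, S = y.shape[0], len(l_prime)
    alpha = np.zeros((T, S))

    # Initialisation: a path may start with the leading blank or the first true label.
    alpha[0, 0] = y[0, l_prime[0]]
    alpha[0, 1] = y[0, l_prime[1]]

    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]                      # stay on the same position
            if s >= 1:
                a += alpha[t - 1, s - 1]             # advance from the previous position
            # Third term only when l'_s is a label distinct from l'_{s-2},
            # i.e. the blank at position s-1 may be skipped.
            if s >= 2 and l_prime[s] != blank and l_prime[s] != l_prime[s - 2]:
                a += alpha[t - 1, s - 2]
            alpha[t, s] = a * y[t, l_prime[s]]

    # Total probability: paths end on the final blank or on the final label.
    p = alpha[T - 1, S - 1] + alpha[T - 1, S - 2]
    return alpha, p
```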
  • The Backward Algorithm defines variables $\beta_t(s)$ analogously: $\beta_t(s)$ gives the probability of generating the remainder of the modified $\mathbf{l'}$, from its $s^{th}$ character onward, over time steps $t$ to $T$.
    \(\beta_t(s)=\begin{cases} (\beta_{t+1}(s)+\beta_{t+1}(s+1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \epsilon \\ (\beta_{t+1}(s)+\beta_{t+1}(s+1))y_{\mathbf{l_s'}}^t & \mathbf{l_s'} \ = \ \mathbf{l_{s+2}'} \\ (\beta_{t+1}(s)+\beta_{t+1}(s+1)+\beta_{t+1}(s+2))y_{\mathbf{l_s'}}^t & \text{otherwise} \\ \end{cases}\)

  • Now we can compute the derivative of the negative log-likelihood
    \(-\frac{\partial \ \ln p(\mathbf{l}|\mathbf{x})}{\partial y_k^t}\)

  • Note, from the two algorithms and the law of total probability (the factor $y_{\mathbf{l_s'}}^t$ appears once in $\alpha_t(s)$ and once in $\beta_t(s)$, which is why we divide by it), that for any time $t$ \(p(\mathbf{l} | \mathbf{x}) \ = \ \sum_{s=1}^{|\mathbf{l'}|} \ \frac{\alpha_t(s) \ \beta_t(s)}{y_{\mathbf{l_s'}}^t}\)
  • Now the computation becomes easy: differentiating this expression wrt. $y_k^t$ only involves the positions $s$ where $\mathbf{l_s'} = k$, as sketched below.
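
To close the loop, here is a matching sketch of the backward pass and the resulting gradient, under the same assumptions and naming as the forward sketch above (again unnormalised, for clarity rather than numerical stability):

```python
import numpy as np

def ctc_backward(y, l_prime, blank):
    """Backward variables beta[t, s], mirroring ctc_forward above."""
    T, S = y.shape[0], len(l_prime)
    beta = np.zeros((T, S))

    # Initialisation at t = T: a path may end on the final blank or the final label.
    beta[T - 1, S - 1] = y[T - 1, l_prime[S - 1]]
    beta[T - 1, S - 2] = y[T - 1, l_prime[S - 2]]

    for t in range(T - 2, -1, -1):
        for s in range(S - 1, -1, -1):
            b = beta[t + 1, s]
            if s + 1 < S:
                b += beta[t + 1, s + 1]
            # Skip the blank at s+1 only when l'_s is a label distinct from l'_{s+2}.
            if s + 2 < S and l_prime[s] != blank and l_prime[s] != l_prime[s + 2]:
                b += beta[t + 1, s + 2]
            beta[t, s] = b * y[t, l_prime[s]]
    return beta

def ctc_grad(y, alpha, beta, l_prime, p):
    """Gradient of -ln p(l|x) with respect to each output y[t, k]."""
    T, K = y.shape
    dp = np.zeros((T, K))          # dp[t, k] = d p(l|x) / d y_k^t
    for t in range(T):
        for s, k in enumerate(l_prime):
            # alpha_t(s) * beta_t(s) contains the factor y[t, k] twice, so dividing
            # by y[t, k]^2 leaves the derivative contribution of position s.
            dp[t, k] += alpha[t, s] * beta[t, s] / (y[t, k] ** 2)
    return -dp / p
```

For example, `alpha, p = ctc_forward(y, l_prime, blank)`, then `beta = ctc_backward(y, l_prime, blank)` and `grad = ctc_grad(y, alpha, beta, l_prime, p)` give the per-time-step gradients that are fed back into the RNN.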

Conclusion

While traditional loss functions such as MSE assume a one-to-one alignment between network outputs and target labels, this is often not the case in tasks like speech recognition.

The CTC loss gives us a way to quantify the error of the output by interpreting the outputs as a probability distribution over label sequences given the input. It requires no prior alignment or segmentation of the inputs.

References

[1] https://www.youtube.com/watch?v=UMxvZ9qHwJs&t=542s

[2] Graves et al. (2006). Connectionist temporal classification: Labelling unsegmented sequence data with recurrent neural networks.
