For tasks that involve generating natural language, a common practice is to train the model with teacher forcing: during training, the model is always asked to predict the next word given the previous ground-truth words as its input. At test time, however, the model must generate each word conditioned on its own previously generated words. As a result, mistakes made early in a sequence can quickly accumulate, because the model was never exposed to its own errors during training. This phenomenon is known as exposure bias.
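To make the train/test mismatch concrete, here is a minimal Python sketch. It assumes a hypothetical toy predictor (`make_bigram_predictor`) that looks only at the last token and has learned one wrong transition ("the" → "dog"). Under teacher forcing the mistake stays local to one step. In free-running generation it changes every later input.

```python
from typing import Callable, Dict, List

BOS = "<s>"


def make_bigram_predictor(table: Dict[str, str]) -> Callable[[List[str]], str]:
    """Hypothetical toy 'model': predicts the next token from the last token only."""
    def predict(prefix: List[str]) -> str:
        return table[prefix[-1]]
    return predict


def teacher_forcing_predictions(predict: Callable[[List[str]], str], target: List[str]) -> List[str]:
    # Every step is conditioned on the ground-truth prefix, so a wrong
    # prediction at one step never feeds into the inputs of later steps.
    return [predict([BOS] + target[:t]) for t in range(len(target))]


def free_running_generation(predict: Callable[[List[str]], str], length: int) -> List[str]:
    # At test time every step is conditioned on the model's own outputs,
    # so a single wrong token changes all subsequent inputs.
    out: List[str] = []
    for _ in range(length):
        out.append(predict([BOS] + out))
    return out


target = ["the", "cat", "sat", "down"]
# The toy model has learned one wrong transition: "the" -> "dog".
table = {BOS: "the", "the": "dog", "cat": "sat", "sat": "down",
         "dog": "barked", "barked": "loudly"}
predict = make_bigram_predictor(table)

tf = teacher_forcing_predictions(predict, target)
fr = free_running_generation(predict, len(target))
print(tf)  # ['the', 'dog', 'sat', 'down']
print(fr)  # ['the', 'dog', 'barked', 'loudly']
```

With teacher forcing the toy model gets three of four next-token predictions right. When it consumes its own output, the single error derails the rest of the sequence.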
These models are predominantly trained via Maximum Likelihood Estimation (MLE), whose objective does not necessarily correlate with the perceived quality of the generated text. Maximizing the likelihood is equivalent to minimizing the Kullback-Leibler (KL) divergence KL(P‖G) between the real data distribution P and the estimated model distribution G. However, the KL divergence is asymmetric and has well-known limitations as a training objective: KL(P‖G) heavily penalizes G for assigning low probability to real data, but barely penalizes it for assigning high probability to unrealistic samples. This paper therefore proposes to optimize the Jensen-Shannon divergence (JSD) between P and G instead.
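The equivalence between MLE and KL minimization follows from KL(P‖G) = −E_P[log G] − H(P), where the entropy H(P) does not depend on the model. The Python sketch below checks this identity on small, made-up discrete distributions (all values are illustrative assumptions) and shows that the two directions of the KL divergence generally give different values.

```python
import math
from typing import Sequence


def entropy(p: Sequence[float]) -> float:
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p: Sequence[float], g: Sequence[float]) -> float:
    # Negative expected log-likelihood of model g under data distribution p.
    return -sum(pi * math.log(gi) for pi, gi in zip(p, g) if pi > 0)


def kl(p: Sequence[float], g: Sequence[float]) -> float:
    return sum(pi * math.log(pi / gi) for pi, gi in zip(p, g) if pi > 0)


p = [0.7, 0.2, 0.1]   # toy "data" distribution P
g1 = [0.6, 0.3, 0.1]  # two candidate models
g2 = [0.4, 0.4, 0.2]

for g in (g1, g2):
    # KL(P || G) = cross-entropy(P, G) - H(P). H(P) is fixed, so ranking
    # models by likelihood is the same as ranking them by KL(P || G).
    assert math.isclose(kl(p, g), cross_entropy(p, g) - entropy(p))

print(kl(p, g2), kl(g2, p))  # the two directions differ
```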
The JSD requires an intermediate distribution M = (P + G)/2, the equal-weight mixture of the true distribution P and the model distribution G, so that JSD(P‖G) = ½ KL(P‖M) + ½ KL(G‖M). In this paper, the authors propose a mediator model that approximates this intermediate distribution. The generative model is then trained to minimize the JSD estimate provided by the mediator. This results in an iterative algorithm that alternates between updating the mediator and updating the generator. It can also be viewed as a cooperative objective in which the mediator and the generator jointly maximize the expected log-likelihood under both P and G, which gives the paper its name: Cooperative Training (CoT).
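The alternating scheme can be illustrated on a toy problem. The sketch below is a simplification and not the paper's implementation. It assumes a three-symbol vocabulary, a softmax-parameterized generator G and mediator M (both hypothetical), and exact expectations in place of sampled sequences and RNNs. The mediator ascends 0.5·E_P[log M] + 0.5·E_G[log M], whose maximizer is the mixture (P + G)/2. The generator then descends the G-dependent part of the mediator's JSD estimate, 0.5·KL(G‖M), with M held fixed.

```python
import math
from typing import List


def softmax(z: List[float]) -> List[float]:
    mx = max(z)
    e = [math.exp(v - mx) for v in z]
    s = sum(e)
    return [v / s for v in e]


def kl(a: List[float], b: List[float]) -> float:
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


def jsd(p: List[float], g: List[float]) -> float:
    m = [(pi + gi) / 2 for pi, gi in zip(p, g)]
    return 0.5 * kl(p, m) + 0.5 * kl(g, m)


def mediator_step(phi: List[float], p: List[float], g: List[float], lr: float) -> List[float]:
    # Gradient ascent on 0.5*E_P[log M] + 0.5*E_G[log M]. For softmax logits
    # the gradient is (mixture - M), so M is pulled toward (P + G) / 2.
    m = softmax(phi)
    return [f + lr * ((pi + gi) / 2 - mi) for f, pi, gi, mi in zip(phi, p, g, m)]


def generator_step(theta: List[float], phi: List[float], lr: float) -> List[float]:
    # Gradient descent on 0.5 * KL(G || M) with the mediator M held fixed.
    # For softmax logits: d/dtheta_y = 0.5 * g_y * (h_y - E_G[h]), h = log(G/M).
    g, m = softmax(theta), softmax(phi)
    h = [math.log(gi / mi) for gi, mi in zip(g, m)]
    mean_h = sum(gi * hi for gi, hi in zip(g, h))
    grad = [0.5 * gi * (hi - mean_h) for gi, hi in zip(g, h)]
    return [t - lr * d for t, d in zip(theta, grad)]


p = [0.5, 0.3, 0.2]      # toy "true" distribution P
theta = [0.0, 0.0, 0.0]  # generator logits: G starts uniform
phi = [0.0, 0.0, 0.0]    # mediator logits

start = jsd(p, softmax(theta))
for _ in range(3000):
    phi = mediator_step(phi, p, softmax(theta), lr=1.0)  # update the mediator
    theta = generator_step(theta, phi, lr=1.0)           # then the generator

g = softmax(theta)
print(start, jsd(p, g))
print([round(x, 3) for x in g])
```

At the fixed point the mediator equals (P + G)/2 and the generator equals the mediator, which forces G = P. The learning rates and iteration count are arbitrary choices for this toy.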
Although the experiments in the paper are mostly language-related, such a strategy can, in principle, be applied to many other types of data. One caveat, however, is that CoT requires a factorized (autoregressive) form of the density function. This is straightforward for natural language, since the probability of a sentence can be decomposed token by token with an RNN language model. For images, one could opt for a model like PixelRNN, which factorizes an image pixel by pixel; however, sampling images one pixel at a time, which CoT must do during training, may be prohibitively slow.
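The factorization requirement amounts to the chain rule: log p(x_1, ..., x_T) = Σ_t log p(x_t | x_<t). The sketch below assumes a hypothetical two-token conditional model (`next_token_probs`) and computes sequence log-probabilities this way; summing the probabilities of all length-3 sequences recovers a total of 1 up to rounding. Generating from such a model takes one sequential step per token, which for an image modeled by PixelRNN means one step per pixel.

```python
import itertools
import math
from typing import Dict, List

VOCAB = ["a", "b"]


def next_token_probs(prefix: List[str]) -> Dict[str, float]:
    # Hypothetical conditional model p(x_t | x_<t): prefers repeating the last token.
    if not prefix:
        return {"a": 0.6, "b": 0.4}
    last = prefix[-1]
    return {tok: (0.8 if tok == last else 0.2) for tok in VOCAB}


def sequence_log_prob(seq: List[str]) -> float:
    # Chain rule: log p(x_1..x_T) = sum over t of log p(x_t | x_1..x_{t-1}).
    return sum(math.log(next_token_probs(seq[:t])[tok]) for t, tok in enumerate(seq))


# Summing over every length-3 sequence should give a total probability of 1.
total = sum(math.exp(sequence_log_prob(list(s)))
            for s in itertools.product(VOCAB, repeat=3))
print(total)
```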