Generative Adversarial Networks for Text Generation — Part 1
The issues with GANs for text generation and the methods being used to combat them
It’s no secret that Generative Adversarial Networks (GANs) have become a huge success in the Computer Vision world for generating hyper-realistic images. Some of the samples produced by the most recent GAN variants are astonishing. Here are a few of them picked from a recent paper:
Building on their success in generation, image GANs have also been used for tasks such as data augmentation, image upsampling, text-to-image synthesis and, more recently, style-based generation, which allows control over both fine and coarse features within generated images.
Like a lot of people, my first reaction to learning about these results was, “Let’s apply this to text!” But when I finally got to working out all the details of the model, the issues became clear. Before we look at what they are, let’s take a quick look at what GANs are. Since there are plenty of detailed tutorials out there, I will keep it brief.
Generative Adversarial Networks (GANs)
The simplest way of looking at a GAN is as a generator network that is trained to produce realistic samples by introducing an adversary, i.e., the discriminator network, whose job is to detect whether a given sample is “real” or “fake”. Another way that I like to look at it is that the discriminator is a dynamically updated evaluation metric for tuning the generator. Both the generator and the discriminator continuously improve until an equilibrium point is reached:
- The generator improves as it receives feedback as to how well its generated samples managed to fool the discriminator.
- The discriminator improves by being shown not only the “fake” samples generated by the generator, but also “real” samples drawn from the real data distribution. This way, it learns what generated samples look like and what real samples look like, which enables it to give better feedback to the generator.
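This is essentially a minimax game played by the two networks, whose value function is given by:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$
where:
- D(x) is the probability that x is “real” according to the discriminator.
- G(z) is a sample generated by the generator given a latent vector z drawn from a prior distribution p_z.
- p_data is the distribution of the real data.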
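As a rough illustration of the value function V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))], here is a minimal sketch that estimates it by replacing each expectation with a batch average. It assumes we already have the discriminator's output probabilities for a batch of real samples and a batch of generated samples; the function name `gan_value` is hypothetical.

```python
import math


def gan_value(d_real, d_fake):
    """Batch estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))].

    d_real: discriminator outputs D(x) on a batch of real samples.
    d_fake: discriminator outputs D(G(z)) on a batch of generated samples.
    All values must lie strictly between 0 and 1.
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term


confused = gan_value([0.5] * 8, [0.5] * 8)      # D cannot tell real from fake
confident = gan_value([0.99] * 8, [0.01] * 8)   # D separates the two batches well
```

A confused discriminator that outputs 0.5 for everything gives V = log 0.5 + log 0.5 = -log 4 ≈ -1.386, which is the value of the game at the theoretical equilibrium where the generator's distribution matches p_data. A discriminator that confidently separates the two batches pushes V toward its maximum of 0.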
Intuitively, the value function says that:
- The discriminator wants to maximize the probability of the real data being identified as “real” and the generated data being identified as “fake”.
- The generator wants to minimize the probability that the discriminator identifies its generated data as “fake”.
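These two roles can be sketched as alternating gradient steps on V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]: the discriminator takes an ascent step, then the generator takes a descent step. This is a toy under loudly stated assumptions, not how real GANs are implemented: 1D real data drawn from N(4, 1), a hypothetical linear generator G(z) = a * z + b, a hypothetical logistic discriminator D(x) = sigmoid(w * x + c), and hand-derived gradients in place of a deep-learning framework.

```python
import math
import random


def sigmoid(t):
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def train_toy_gan(steps=200, lr=0.05, batch=64, seed=0):
    """Alternating updates on V(D, G) for a hypothetical 1D toy GAN.

    Generator:     G(z) = a * z + b,        z ~ N(0, 1)
    Discriminator: D(x) = sigmoid(w * x + c)
    Real data:     x ~ N(4, 1)
    """
    rng = random.Random(seed)
    a, b = 1.0, 0.0
    w, c = 0.1, 0.0
    for _ in range(steps):
        real = [rng.gauss(4.0, 1.0) for _ in range(batch)]
        zs = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        fake = [a * z + b for z in zs]

        # Discriminator: gradient ASCENT on mean log D(x) + mean log(1 - D(G(z))).
        gw = gc = 0.0
        for x in real:
            d = sigmoid(w * x + c)
            gw += (1.0 - d) * x  # d/dw log D(x)
            gc += 1.0 - d
        for x in fake:
            d = sigmoid(w * x + c)
            gw -= d * x  # d/dw log(1 - D(x)) = -D(x) * x
            gc -= d
        w += lr * gw / batch
        c += lr * gc / batch

        # Generator: gradient DESCENT on mean log(1 - D(G(z))), using the updated D.
        ga = gb = 0.0
        for z in zs:
            d = sigmoid(w * (a * z + b) + c)
            # d/dx log(1 - D(x)) = -D(x) * w, with dx/da = z and dx/db = 1.
            ga -= d * w * z
            gb -= d * w
        a -= lr * ga / batch
        b -= lr * gb / batch
    return a, b, w, c


a, b, w, c = train_toy_gan()
```

After the very first round, b moves up toward the real mean because the discriminator has just learned that larger x looks more “real”. No convergence is claimed here; the sketch only shows how the feedback from each network drives the other's update.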
Though this intuition will suffice for the purposes of this article, I highly recommend you read a more in-depth mathematical analysis of GANs and their problems here.
GANs for Images
Let’s take a brief look at how GANs work for images.
In the case of images, G(z) is the generated image and x is sampled from a dataset of real images.
Here, the output of the generator, G(z), is simply a matrix of real values which we interpret as an image. The discriminator then takes this matrix of real values as input and classifies it as fake (0) or real (1).
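To make the data flow concrete, here is a minimal sketch with a hypothetical generator that maps a latent vector z to a 2x2 “image” of real values (a linear map followed by tanh) and a hypothetical logistic-regression discriminator over the flattened pixels. Real image GANs use deep convolutional networks; only the shape of the pipeline is meant to carry over.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def generator(z, weights):
    # Each of the 4 weight rows produces one pixel; tanh keeps pixels in (-1, 1).
    flat = [math.tanh(sum(w * zj for w, zj in zip(row, z))) for row in weights]
    return [flat[0:2], flat[2:4]]  # reshape into a 2x2 matrix of real values


def discriminator(image, w, c):
    # Logistic regression on the flattened pixels: returns P(image is "real").
    pixels = [p for row in image for p in row]
    return sigmoid(sum(wi * p for wi, p in zip(w, pixels)) + c)


rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(3)]
gen_weights = [[rng.gauss(0.0, 0.5) for _ in range(3)] for _ in range(4)]
image = generator(z, gen_weights)
score = discriminator(image, [0.5, -0.5, 0.25, 1.0], 0.0)
```

Every operation here (weighted sums, tanh, sigmoid) is differentiable, so the gradient of D(G(z)) with respect to the generator's weights is well defined.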
Pretty straightforward, right?
So, what’s the issue with text?
First, let’s take a look at how text generation is done using a simple RNN-based text generator.
At every time step t, the RNN takes the previously generated token and the previous hidden state as input and generates the new hidden state, hᵗ.
The hidden state is then passed through a linear layer and a softmax layer, followed by an argmax, to yield the next word.
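One decoding step can be sketched as follows. This assumes a plain Elman-style RNN with a tanh activation and no biases, a tiny vocabulary of 4 token ids, and randomly initialised weights; the names `make_params`, `rnn_step` and `greedy_decode` are hypothetical. The starting hidden state `h0` is the slot that will later hold the latent vector z.

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vec):
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def make_params(vocab=4, emb=2, hidden=3, seed=0):
    # Hypothetical randomly initialised weights (no biases, for brevity).
    rng = random.Random(seed)

    def rand(rows, cols):
        return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

    return {"E": rand(vocab, emb), "W_x": rand(hidden, emb),
            "W_h": rand(hidden, hidden), "W_out": rand(vocab, hidden)}


def rnn_step(prev_token, h_prev, params):
    x = params["E"][prev_token]  # embedding of the previously generated token
    pre = [a + b for a, b in zip(matvec(params["W_x"], x), matvec(params["W_h"], h_prev))]
    h_t = [math.tanh(v) for v in pre]  # new hidden state h^t
    probs = softmax(matvec(params["W_out"], h_t))  # linear layer + softmax
    next_token = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return next_token, h_t, probs


def greedy_decode(h0, start_token, params, max_len=5):
    tokens, h, tok = [], h0, start_token
    for _ in range(max_len):
        tok, h, _ = rnn_step(tok, h, params)
        tokens.append(tok)
    return tokens


params = make_params()
tokens = greedy_decode([0.0, 0.0, 0.0], start_token=0, params=params)
```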
The RNN is trained by making it predict the next word in a sentence at each time step. Training is done by back-propagating the cross-entropy loss between the output distribution of the softmax layer and the target one-hot vector.
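The loss for a single time step can be sketched like this: a minimal version, with the hypothetical name `cross_entropy_and_grad`, that computes the cross-entropy between the softmax output and a one-hot target and returns its gradient with respect to the logits, which has the well-known closed form “softmax output minus one-hot target”.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_and_grad(logits, target):
    """Cross-entropy between softmax(logits) and a one-hot target, plus d(loss)/d(logits)."""
    probs = softmax(logits)
    loss = -math.log(probs[target])
    # Gradient w.r.t. the logits: softmax output minus the one-hot target.
    grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]
    return loss, grad


loss, grad = cross_entropy_and_grad([2.0, 1.0, 0.1], target=0)
```

Every step from this loss back through the softmax, the linear layer and the hidden states is differentiable. Note that the argmax is not on this path at all; during this kind of training it is only used when decoding.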
Now, consider this RNN-based generator to be the generator network in a GAN. Here, the latent vector z is the input hidden state h⁰ of the RNN, and the generator output G(z) is the sentence output by the RNN. The difference here is that instead of training the RNN to minimize cross-entropy loss with respect to target one-hot vectors, we will be training it to increase the probability of the discriminator network classifying the sentence as “real”, i.e., the objective now is to minimize 1 - D(G(z)).
Remember that while decoding using an RNN, at every time step we make the choice of the next word by picking the word corresponding to the maximum probability from the output of the softmax function. This “picking” operation is non-differentiable.
Why is this an issue? It’s an issue because, in order to train the generator to minimize 1 - D(G(z)), we need to feed the output of the generator to the discriminator and back-propagate the corresponding loss of the discriminator. For these gradients to reach the generator, they have to go through the non-differentiable “picking” operation at the output of the generator. This is problematic as back-propagation relies on the differentiability of all the layers in the network.
In contrast, note that this is perfectly feasible when the generated data is continuous, such as images, as shown above.
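A quick numerical check makes the problem visible. The sketch below, with hypothetical helper names, nudges one logit up and down and measures how each output changes: the one-hot argmax output does not move at all (its derivative is zero almost everywhere and undefined at ties), while the softmax probabilities respond smoothly. It is an illustration, not a proof.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot_argmax(logits):
    # The "picking" step: a one-hot vector for the highest-scoring word.
    k = max(range(len(logits)), key=logits.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(logits))]


def finite_diff(f, logits, i, eps=1e-4):
    # Central-difference estimate of d f(logits) / d logits[i], one entry per output.
    up, down = list(logits), list(logits)
    up[i] += eps
    down[i] -= eps
    return [(a - b) / (2 * eps) for a, b in zip(f(up), f(down))]


logits = [2.0, 1.0, 0.1]
print(finite_diff(one_hot_argmax, logits, 1))  # [0.0, 0.0, 0.0]: the pick does not change
print(finite_diff(softmax, logits, 1))         # non-zero: probabilities shift smoothly
```

If the discriminator only ever sees the one-hot vectors, any gradient flowing back from D(G(z)) is multiplied by these zeros before it can reach the generator's weights.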
In recent times, people have proposed various methods to circumvent this issue. These methods can broadly be classified into:
- The REINFORCE algorithm and policy gradients (Reinforcement Learning-based solutions)
- The Gumbel-Softmax approximation (A continuous relaxation of sampling discrete one-hot tokens from the softmax distribution)
- Avoiding discrete spaces altogether by working with the continuous output of the generator
I will be delving deeper into these approaches in the next part of this series.
In this post, we briefly looked at GANs and why they run into problems when applied to text generation. Check out Part 2 (coming out soon) to learn about the above-mentioned solutions in detail. Hope you enjoyed reading!
Generative Adversarial Networks for Text Generation — Part 1 was originally published in Becoming Human: Artificial Intelligence Magazine on Medium, where people are continuing the conversation by highlighting and responding to this story.