The full code is available here: https://github.com/wiseodd/generative-models.
Vanilla GAN is a method for learning the marginal distribution of the data.
Coupled GAN (CoGAN) extends GAN so that it can learn a joint distribution while only needing samples from the marginals. In other words, we never need to sample from the joint distribution itself, i.e. paired samples from the two domains, during training.
Learning joint distribution by sharing weights
So, how exactly does CoGAN learn the joint distribution using only the marginals?
The trick here is to add a constraint such that high-level representations of the data are shared. Specifically, we constrain our networks to have the same weights on several layers. The intuition is that by constraining those weights to be identical, CoGAN will converge to an optimum where they encode a shared (joint) representation of both domains.
But which layers should be constrained? To answer this, observe that neural nets used for classification learn data representations bottom-up, i.e. from low-level to high-level representations. Low-level representations are highly specialized to the data and not general enough to share. Hence, we constrain the layers that encode the high-level representation.
Intuitively, the lower-level layers capture image-specific features, e.g. the thickness of edges, the saturation of colors, etc. Higher-level layers capture more general features, such as the abstract representation of “bird”, “dog”, etc., ignoring the color or the thickness of the images. So, naturally, to capture a joint representation of the data, we want to share the higher-level layers, then use the lower-level layers to encode those abstract representations into image-specific features, so that we get correct (in a general sense) and plausible (in a detailed sense) images.
Using that reasoning, we can then choose which layers should be constrained. For the discriminator, it should be the last layers. For the generator, it should be the first layers, as the generator in a GAN solves the inverse problem: from latent representation z to image X.
CoGAN algorithm
If we want to learn the joint distribution of several domains, we need one generator and one discriminator per domain.
The algorithm for CoGAN for 2 domains is as follows:
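As a rough sketch of what the two-domain procedure optimizes, the objective can be written in standard GAN notation as below. The symbols are assumptions made for illustration: $G_1, G_2$ are the generators, $D_1, D_2$ the discriminators, $p_{X_1}, p_{X_2}$ the two marginals, and $p_Z$ the prior over the latent code $z$.

```latex
\min_{G_1, G_2} \; \max_{D_1, D_2} \;
    \mathbb{E}_{x_1 \sim p_{X_1}} \left[ \log D_1(x_1) \right]
  + \mathbb{E}_{z \sim p_Z} \left[ \log \left( 1 - D_1(G_1(z)) \right) \right]
  + \mathbb{E}_{x_2 \sim p_{X_2}} \left[ \log D_2(x_2) \right]
  + \mathbb{E}_{z \sim p_Z} \left[ \log \left( 1 - D_2(G_2(z)) \right) \right]
```

The optimization is subject to $G_1$ and $G_2$ sharing the weights of their first layers, and $D_1$ and $D_2$ sharing the weights of their last layers. Each training iteration alternates a gradient step on the discriminators with a gradient step on the generators, using minibatches drawn independently from the two marginals.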
Notice that CoGAN draws samples from each marginal distribution. That means we only need two sets of training data; we do not need to construct specialized training data that captures the joint distribution of the two domains. However, since we learn the joint distribution through weight sharing on high-level features, CoGAN training can only succeed if the two domains actually share some high-level representation.
PyTorch implementation of CoGAN
In this implementation, we are going to learn the joint distribution of two domains of MNIST data: normal MNIST digits and MNIST digits rotated by 90 degrees. Notice that those domains share the same high-level representation (the digit) and differ only in presentation (low-level features). Here’s the code to generate those training sets:
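A minimal sketch of how two such unpaired training sets might be built. It assumes the images arrive as a flattened array of shape (N, 784); the helper name `make_domains`, the half/half split, and the random stand-in data are illustrative choices, not necessarily what the linked repository does.

```python
import numpy as np


def make_domains(images, side=28):
    """Split flattened square images into two unpaired domains.

    Domain 1 keeps the first half of the images unchanged. Domain 2 is the
    second half rotated by 90 degrees, so no image appears in both domains.
    """
    half = images.shape[0] // 2
    X1 = images[:half]
    X2 = images[half:].reshape(-1, side, side)
    # Rotate every image in the batch by 90 degrees (counterclockwise).
    X2 = np.rot90(X2, k=1, axes=(1, 2)).reshape(-1, side * side)
    return X1, X2


# Stand-in for MNIST (assumption): random arrays with MNIST's shape, so the
# sketch runs without downloading the dataset.
rng = np.random.default_rng(0)
fake_mnist = rng.random((100, 784), dtype=np.float32)

X1, X2 = make_domains(fake_mnist)
```

Because the two halves are disjoint, the model never sees a digit together with its own rotated version, which is exactly the "marginals only" setting described above.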
Let’s declare the generators first. They are two-layer fully connected nets whose first layer (input to hidden) is shared. We then make a wrapper for those nets, so that each generator is the shared layer followed by its own output layer. Notice that G_shared is used in both nets.
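The generator setup described above could look roughly like the following sketch. The layer sizes, the activation functions, and the names `G1_head` and `G2_head` are assumptions for illustration; only `G_shared`, `G1`, and `G2` are names taken from the text.

```python
import torch
import torch.nn as nn

z_dim, h_dim, X_dim = 100, 128, 784  # assumed sizes (784 = 28 * 28)

# First layer (latent code -> hidden), shared by both generators.
G_shared = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())

# Domain-specific output layers (hidden -> image).
G1_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())
G2_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())


def G1(z):
    """Generator for domain 1: shared layer, then its own output layer."""
    return G1_head(G_shared(z))


def G2(z):
    """Generator for domain 2: same shared layer, different output layer."""
    return G2_head(G_shared(z))


# Feeding the same z through both generators yields one sample per domain.
z = torch.randn(16, z_dim)
X1_fake, X2_fake = G1(z), G2(z)
```

Since both wrappers call the same `G_shared` module, there is only one copy of those weights, and gradients from both generators flow into it.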
The discriminators are also two-layer nets, similar to the generators, but they share weights on the last layer: hidden to output.
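A matching sketch for the discriminators, mirroring the generator layout. As before, the sizes, activations, and the names `D1_body` and `D2_body` are illustrative assumptions; `D_shared`, `D1`, and `D2` follow the text.

```python
import torch
import torch.nn as nn

X_dim, h_dim = 784, 128  # assumed sizes

# Domain-specific first layers (image -> hidden).
D1_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())
D2_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())

# Last layer (hidden -> probability that the input is real), shared by both.
D_shared = nn.Sequential(nn.Linear(h_dim, 1), nn.Sigmoid())


def D1(X):
    """Discriminator for domain 1: its own first layer, then the shared one."""
    return D_shared(D1_body(X))


def D2(X):
    """Discriminator for domain 2: its own first layer, then the shared one."""
    return D_shared(D2_body(X))


# Random stand-in batches, one per domain.
p1 = D1(torch.rand(8, X_dim))
p2 = D2(torch.rand(8, X_dim))
```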
Next, we construct the optimizer:
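One way this step might look, as a sketch: collect the parameters of the generator side and of the discriminator side, listing each shared module once, and hand each list to its own optimizer. The choice of Adam, the learning rate, the helper `params_of`, and the single-layer stand-in modules are assumptions.

```python
import itertools

import torch
import torch.nn as nn

z_dim, h_dim, X_dim, lr = 100, 128, 784, 1e-3  # assumed hyperparameters

# Bare linear layers stand in for the real nets to keep the sketch short.
G_shared = nn.Linear(z_dim, h_dim)
G1_head, G2_head = nn.Linear(h_dim, X_dim), nn.Linear(h_dim, X_dim)
D1_body, D2_body = nn.Linear(X_dim, h_dim), nn.Linear(X_dim, h_dim)
D_shared = nn.Linear(h_dim, 1)


def params_of(*modules):
    """Flatten the parameters of several modules into one list."""
    return list(itertools.chain.from_iterable(m.parameters() for m in modules))


# Each shared module is listed exactly once, so no parameter is duplicated.
G_params = params_of(G_shared, G1_head, G2_head)
D_params = params_of(D1_body, D2_body, D_shared)

G_solver = torch.optim.Adam(G_params, lr=lr)
D_solver = torch.optim.Adam(D_params, lr=lr)
```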
Now we are ready to train CoGAN. At each training iteration, we do the steps below. First, we sample images from both marginal training sets, along with a latent code z.
Then we train the discriminators, using X1 for D1 and X2 for D2. For both discriminators, we use the same z. The loss function is just the vanilla GAN loss.
Then we just add up those losses. During backpropagation, D_shared will naturally get gradients from both D1 and D2, i.e. the sum of both branches. All we need to do to get the average is to scale those gradients by one half.
Now that we have all the gradients, we can update the weights:
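Putting the discriminator step together, a sketch under stated assumptions: the nets are small stand-ins, the minibatches are random tensors rather than MNIST, the optimizer is Adam, and the vanilla GAN loss is expressed through binary cross-entropy. The helper `d_loss` and the `*_body` / `*_head` names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
mb, z_dim, h_dim, X_dim = 8, 16, 32, 784  # assumed sizes

G_shared = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())
G1_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())
G2_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())
D1_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())
D2_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())
D_shared = nn.Sequential(nn.Linear(h_dim, 1), nn.Sigmoid())


def G1(z):
    return G1_head(G_shared(z))


def G2(z):
    return G2_head(G_shared(z))


def D1(X):
    return D_shared(D1_body(X))


def D2(X):
    return D_shared(D2_body(X))


D_modules = [D1_body, D2_body, D_shared]
D_solver = torch.optim.Adam(
    [p for m in D_modules for p in m.parameters()], lr=1e-3
)

# Random stand-ins for one minibatch from each marginal, plus one shared z.
X1, X2 = torch.rand(mb, X_dim), torch.rand(mb, X_dim)
z = torch.randn(mb, z_dim)
ones, zeros = torch.ones(mb, 1), torch.zeros(mb, 1)

with torch.no_grad():  # the generators are not updated in this step
    X1_fake, X2_fake = G1(z), G2(z)


def d_loss(D, X_real, X_fake):
    """Vanilla GAN discriminator loss: real -> 1, fake -> 0."""
    real_term = F.binary_cross_entropy(D(X_real), ones)
    fake_term = F.binary_cross_entropy(D(X_fake), zeros)
    return real_term + fake_term


D_loss = d_loss(D1, X1, X1_fake) + d_loss(D2, X2, X2_fake)

D_solver.zero_grad()
D_loss.backward()
# D_shared received the sum of both branches' gradients; halve to average.
for p in D_shared.parameters():
    p.grad.mul_(0.5)
D_solver.step()
```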
Generator training is similar to discriminator training: we need to average the gradients of the G1 and G2 losses w.r.t. G_shared.
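The generator step can be sketched in the same way. The assumptions are the same as for the other sketches (stand-in nets, Adam, hypothetical `*_body` / `*_head` names), and the generator loss used here is the commonly used non-saturating form, where each generator tries to make its discriminator output 1 on generated samples.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
mb, z_dim, h_dim, X_dim = 8, 16, 32, 784  # assumed sizes

G_shared = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())
G1_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())
G2_head = nn.Sequential(nn.Linear(h_dim, X_dim), nn.Sigmoid())
D1_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())
D2_body = nn.Sequential(nn.Linear(X_dim, h_dim), nn.ReLU())
D_shared = nn.Sequential(nn.Linear(h_dim, 1), nn.Sigmoid())


def G1(z):
    return G1_head(G_shared(z))


def G2(z):
    return G2_head(G_shared(z))


def D1(X):
    return D_shared(D1_body(X))


def D2(X):
    return D_shared(D2_body(X))


G_modules = [G_shared, G1_head, G2_head]
G_solver = torch.optim.Adam(
    [p for m in G_modules for p in m.parameters()], lr=1e-3
)

z = torch.randn(mb, z_dim)  # the same z goes through both generators
ones = torch.ones(mb, 1)

# Snapshot of the shared weights, only to show that the step changes them.
shared_before = G_shared[0].weight.detach().clone()

G1_loss = F.binary_cross_entropy(D1(G1(z)), ones)
G2_loss = F.binary_cross_entropy(D2(G2(z)), ones)
G_loss = G1_loss + G2_loss

G_solver.zero_grad()
G_loss.backward()
# G_shared received the sum of both branches' gradients; halve to average.
for p in G_shared.parameters():
    p.grad.mul_(0.5)
G_solver.step()  # only generator parameters are registered in G_solver
```

The backward pass also leaves gradients on the discriminator parameters, but they are never applied here because `G_solver` does not hold them; they should be zeroed before the next discriminator step.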
Results
After many thousands of iterations, G1 and G2 will produce these kinds of samples. Note that the first two rows are the normal MNIST images, and the next two rows are the rotated images. Also, the z fed into G1 and G2 is the same, so that we can see what the two generators produce given the same latent code.
Obviously, if we swap our nets with more powerful ones, we could get higher quality samples.
If we squint, we can see that the images in the third row are roughly 90-degree rotations of those in the first row. Likewise, the images in the fourth row correspond to those in the second row.
This is a marvelous result considering we never explicitly showed CoGAN samples from the joint distribution (i.e. tuples of corresponding images from the two domains).
Conclusion
In this post, we looked at CoGAN: Coupled GAN, a GAN model that is used to learn the joint distribution of data from different domains.
We learned that CoGAN learns the joint distribution by enforcing a weight-sharing constraint on the layers that encode high-level representations. We also noticed that CoGAN only needs to see samples from the marginal distributions, not the joint itself.
Finally, by inspecting the samples acquired from the generators, we saw that CoGAN correctly learns the joint distribution, as those samples correspond to each other.
References
- Liu, Ming-Yu, and Oncel Tuzel. “Coupled generative adversarial networks.” Advances in Neural Information Processing Systems. 2016.