Wasserstein GAN
As you can certainly appreciate by now, GANs have wide and varied applications, several of which suit games very well. One such application is generating textures or texture variations. We often want slight variations in textures to give our game worlds a more convincing look. This can be done with shaders, but for performance reasons it is often better to create static assets.
Therefore, in this section, we will build a GAN project that lets us generate textures or height maps. You could also extend this concept using any of the other cool GANs we briefly touched on earlier. We will start from the Wasserstein GAN implementation by Erik Linder-Norén, from his Keras-GAN repository, and convert it for our purposes.
One of the major hurdles you will face when first approaching deep learning problems is shaping data into the form you need. In the original sample, Erik used the MNIST dataset, but we will convert the sample to use the CIFAR-100 dataset. CIFAR-100 is a set of 32×32 color images grouped into 100 classes, as follows:
Figure: The CIFAR-100 dataset
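Shaping the data is mostly a question of array shapes and value ranges. MNIST images load as grayscale arrays shaped (N, 28, 28), so the sample adds a channel axis with np.expand_dims; CIFAR-100 images, as returned by keras.datasets.cifar100.load_data(), are already shaped (N, 32, 32, 3). The following is a minimal sketch, not the book's code: the helper name to_gan_range is hypothetical, and random arrays stand in for the real datasets so the example runs offline. The model's input shape would also need to change to match the new images.

```python
import numpy as np

def to_gan_range(images):
    """Rescale uint8 images (0-255) to float32 in [-1, 1] and ensure a channel axis.

    Grayscale batches shaped (N, H, W), such as MNIST, gain a trailing channel
    axis; color batches shaped (N, H, W, C), such as CIFAR-100, keep their shape.
    """
    x = (images.astype(np.float32) - 127.5) / 127.5
    if x.ndim == 3:
        x = np.expand_dims(x, axis=3)
    return x

# Hypothetical stand-ins with the same shapes as the real datasets
mnist_like = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)
cifar_like = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

print(to_gan_range(mnist_like).shape)  # (4, 28, 28, 1)
print(to_gan_range(cifar_like).shape)  # (4, 32, 32, 3)
```

With the real data you would pass the x_train array from cifar100.load_data() straight into the helper; the -1 to 1 range matches the tanh output typically used by these generators.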
For now, though, let's open up Chapter_3_wgan.py and follow these steps:
- Open the Python file and review the code. Most of it will look the same as the DCGAN we covered earlier. However, there are a few key differences, which show up in the train method:
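```python
import numpy as np
from keras.datasets import mnist

# train() is a method of the WGAN class in Chapter_3_wgan.py
def train(self, epochs, batch_size=128, sample_interval=50):
    # Load MNIST and rescale pixel values from [0, 255] to [-1, 1]
    (X_train, _), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)  # add a channel axis

    # Critic targets: -1 for real images, +1 for generated ones
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # Train the critic n_critic times for each generator update
        for _ in range(self.n_critic):
            # Sample a random batch of real images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Generate a batch of fake images from random noise
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the critic on the real and fake batches
            d_loss_real = self.critic.train_on_batch(imgs, valid)
            d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            # Clip every critic weight into [-clip_value, clip_value]
            for l in self.critic.layers:
                weights = l.get_weights()
                weights = [np.clip(w, -self.clip_value, self.clip_value)
                           for w in weights]
                l.set_weights(weights)

        # Train the generator (through the combined model) to fool the critic
        g_loss = self.combined.train_on_batch(noise, valid)

        # Report progress
        print("%d [D loss: %f] [G loss: %f]"
              % (epoch, 1 - d_loss[0], 1 - g_loss[0]))

        # Periodically save a grid of generated samples
        if epoch % sample_interval == 0:
            self.sample_images(epoch)
```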
- The Wasserstein GAN uses a distance function to determine the cost, or loss, for each training iteration. It also replaces the discriminator with a critic, which outputs an unbounded score rather than a real/fake probability. That single critic is trained several times (self.n_critic) for every generator update, and its weights are clipped after each step. Because the Wasserstein loss keeps providing useful gradients even when the critic is well trained, training the critic more often stabilizes training and helps avoid the vanishing gradient problem that often plagues GANs.
- A WGAN mitigates the gradient problem by measuring cost with a distance function that determines the cost of moving, rather than with a difference in error values. As a simple analogy, consider counting how many places letters must move to spell a word correctly. For example, the word SOPT has a cost of 2, since the T must move two places to spell STOP. The word OTPS has a cost of 3 (S) + 1 (T) = 4 to spell STOP.
- The Wasserstein distance essentially measures the minimum cost of transforming one probability distribution into another. As you can imagine, the math behind it is quite complex, so we will leave it to the interested reader.
- Run the example. This sample can take a significant amount of time to run, so be patient. It has also been known to have trouble training on some GPU hardware. If that happens to you, disable the GPU, for example by setting the CUDA_VISIBLE_DEVICES environment variable to -1 before Keras or TensorFlow is imported.
- As the sample runs, open the images folder, located in the same folder as the Python file, and watch the training images being generated.
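The -1/+1 targets and the weight clipping in train only make sense alongside the critic's loss. A WGAN critic in Keras is typically compiled with a loss of the form mean(y_true * y_pred); the NumPy version below is a stand-in for illustration, so check Chapter_3_wgan.py for the exact definition. The critic scores are made-up numbers, and clip_value=0.01 is an assumed default.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    # Mean of label * critic score. With real labels of -1 and fake labels
    # of +1, minimizing this raises real scores and lowers fake scores.
    return np.mean(y_true * y_pred)

def clip_weights(weights, clip_value=0.01):
    # Clamp every weight array into [-clip_value, clip_value], mirroring the
    # np.clip call applied to each critic layer in the train loop.
    return [np.clip(w, -clip_value, clip_value) for w in weights]

batch_size = 4
valid = -np.ones((batch_size, 1))  # targets for real images
fake = np.ones((batch_size, 1))    # targets for generated images

# Made-up critic scores; a critic outputs unbounded scores, not probabilities
real_scores = np.array([[2.0], [1.5], [3.0], [2.5]])
fake_scores = np.array([[-1.0], [0.5], [-0.5], [0.0]])

d_loss_real = wasserstein_loss(valid, real_scores)  # -mean(real) = -2.25
d_loss_fake = wasserstein_loss(fake, fake_scores)   # mean(fake) = -0.25
d_loss = 0.5 * (d_loss_real + d_loss_fake)          # -1.25

# The critic's Wasserstein estimate is mean(real) - mean(fake) = -2 * d_loss
estimate = real_scores.mean() - fake_scores.mean()  # 2.5

clipped = clip_weights([np.array([0.5, -0.003, -2.0])])
print(d_loss, estimate, clipped[0])
```

Lowering d_loss widens the gap between real and fake scores, and that gap is the critic's estimate of the Wasserstein distance. Clipping the weights is the original WGAN's simple way of keeping the critic smooth enough that the gap cannot grow without bound.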
Run the sample for as long as you need to understand how it works; it can take several hours, even on advanced hardware. When you are done, move on to the next section, where we will see how to modify this sample to generate textures.
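To make the cost-of-moving idea from the STOP example concrete, the sketch below computes two things. move_cost reproduces the letter-moving analogy (pull out the letter needed at each position and reinsert it, counting the places it travels); it is only the analogy, not a Wasserstein distance. wasserstein_1d is the real Wasserstein-1 (earth mover's) distance for two equal-sized one-dimensional samples, which reduces to the average gap between the sorted values. Both function names are hypothetical; SciPy's scipy.stats.wasserstein_distance computes the same 1D quantity.

```python
import numpy as np

def move_cost(word, target):
    """Toy cost from the STOP example: for each position i, pull the needed
    letter out of the word, reinsert it at i, and add the places it moved.
    Assumes word is a rearrangement of target."""
    letters = list(word)
    cost = 0
    for i, ch in enumerate(target):
        j = letters.index(ch, i)   # where the needed letter currently sits
        cost += j - i              # places it must move to the left
        letters.insert(i, letters.pop(j))
    return cost

def wasserstein_1d(a, b):
    """Wasserstein-1 distance between two equal-sized 1D samples:
    the mean absolute difference between matched sorted values."""
    a = np.sort(np.asarray(a, dtype=float))
    b = np.sort(np.asarray(b, dtype=float))
    return np.mean(np.abs(a - b))

print(move_cost("SOPT", "STOP"))             # 2
print(move_cost("OTPS", "STOP"))             # 4
print(wasserstein_1d([0, 1, 2], [5, 6, 7]))  # 5.0
```

In the second call, every point has to travel 5 units to turn one sample into the other, so the distance is 5 even though the two samples do not overlap at all. That is exactly the case where a classic discriminator's loss stops giving useful gradients.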