The Algorithms logo
The Algorithms
About | Donate

Variational Autoencoder

Y
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Sampling layer: draws z = z_mean + exp(0.5 * z_log_var) * epsilon (reparameterization trick)

class Sample(layers.Layer):
    """Draw a latent vector z from the Gaussian described by the inputs.

    Uses the reparameterization trick: ``z = z_mean + sigma * epsilon`` with
    ``sigma = exp(0.5 * z_log_var)`` and ``epsilon`` drawn from a standard
    normal, so the result stays differentiable with respect to ``z_mean``
    and ``z_log_var``.
    """

    def call(self, inputs):
        """Return one sample per row of ``z_mean``.

        Args:
            inputs: pair ``(z_mean, z_log_var)`` of equally shaped tensors.

        Returns:
            Tensor with the same shape as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        # tf.random.normal is used instead of tf.keras.backend.random_normal,
        # which is not available in newer Keras releases. The noise follows
        # the dynamic shape and dtype of z_mean.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# Dimensionality of the latent space; read when the VAE encoder and decoder
# are built.
latent_dim: int = 2

# VAE model: convolutional encoder and decoder trained with a custom train_step

class VAE(keras.Model):
    """Convolutional variational autoencoder for 28x28x1 images.

    Attributes:
        encoder: functional model mapping an image batch to
            ``[z_mean, z_log_var, z]``, where ``z`` is sampled by ``Sample``.
        decoder: functional model mapping a ``(latent_dim,)`` vector back to
            a 28x28x1 image with sigmoid-activated pixels.

    The latent size is taken from the module-level ``latent_dim``.
    """

    def __init__(self, **kwargs):
        """Build the encoder and decoder sub-models.

        Args:
            **kwargs: forwarded to ``keras.Model.__init__`` (e.g. ``name``).
        """
        super().__init__(**kwargs)

        # Encoder: two stride-2 convolutions halve the resolution twice
        # (28 -> 14 -> 7) before the dense heads for mean and log-variance.
        encoder_inputs = keras.Input(shape=(28, 28, 1))
        x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
        x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Flatten()(x)
        x = layers.Dense(16, activation="relu")(x)
        z_mean = layers.Dense(latent_dim, name="z_mean")(x)
        z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
        z = Sample()([z_mean, z_log_var])
        self.encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

        # Decoder: mirrors the encoder, upsampling 7 -> 14 -> 28 with
        # transposed convolutions; the sigmoid keeps pixels in [0, 1].
        latent_inputs = keras.Input(shape=(latent_dim,))
        x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
        x = layers.Reshape((7, 7, 64))(x)
        x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
        decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
        self.decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def call(self, inputs):
        """Encode ``inputs``, sample a latent vector and return its decoding.

        Lets the model be invoked directly (``vae(x)``) in addition to being
        trained through ``train_step``.
        """
        _, _, z = self.encoder(inputs)
        return self.decoder(z)

    def train_step(self, data):
        """Run one optimisation step on a batch and report the loss terms.

        Args:
            data: batch of images, or a tuple whose first element is the
                image batch (any remaining elements are ignored).

        Returns:
            Dict with keys ``"loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"``.
        """
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            # Latent mean, log-variance and a sampled latent vector.
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)

            # binary_crossentropy reduces the channel axis; summing the two
            # spatial axes gives the per-image loss, then average over batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )

            # KL divergence to the standard normal prior: summed over the
            # latent dimensions of each sample, then averaged over the batch.
            # Averaging over the latent axis as well would under-weight this
            # term by a factor of latent_dim.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
    
    
        
# Labels are discarded: training uses the train and test images together.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate((x_train, x_test), axis=0)
# Append a trailing channel axis and rescale the pixel values by 1/255.
mnist_digits = mnist_digits[..., np.newaxis].astype("float32") / 255

# Build the model, attach an Adam optimizer and train for 30 epochs.
vae = VAE()
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(x=mnist_digits, batch_size=128, epochs=30)
Epoch 1/30
547/547 [==============================] - 9s 17ms/step - loss: 207.9442 - reconstruction_loss: 206.5694 - kl_loss: 1.3748
Epoch 2/30
547/547 [==============================] - 9s 16ms/step - loss: 167.6381 - reconstruction_loss: 164.8430 - kl_loss: 2.7950
Epoch 3/30
547/547 [==============================] - 9s 17ms/step - loss: 157.1702 - reconstruction_loss: 154.0470 - kl_loss: 3.1231
Epoch 4/30
547/547 [==============================] - 9s 16ms/step - loss: 154.2223 - reconstruction_loss: 151.0141 - kl_loss: 3.2082
Epoch 5/30
547/547 [==============================] - 9s 16ms/step - loss: 152.4078 - reconstruction_loss: 149.1486 - kl_loss: 3.2593
Epoch 6/30
547/547 [==============================] - 9s 16ms/step - loss: 151.1136 - reconstruction_loss: 147.8070 - kl_loss: 3.3066
Epoch 7/30
547/547 [==============================] - 9s 16ms/step - loss: 150.1958 - reconstruction_loss: 146.8621 - kl_loss: 3.3336
Epoch 8/30
547/547 [==============================] - 9s 16ms/step - loss: 149.2447 - reconstruction_loss: 145.8765 - kl_loss: 3.3683
Epoch 9/30
547/547 [==============================] - 9s 16ms/step - loss: 148.5966 - reconstruction_loss: 145.1993 - kl_loss: 3.3972
Epoch 10/30
547/547 [==============================] - 9s 16ms/step - loss: 147.9715 - reconstruction_loss: 144.5540 - kl_loss: 3.4174
Epoch 11/30
547/547 [==============================] - 9s 16ms/step - loss: 147.4371 - reconstruction_loss: 143.9974 - kl_loss: 3.4396
Epoch 12/30
547/547 [==============================] - 9s 16ms/step - loss: 147.0404 - reconstruction_loss: 143.5817 - kl_loss: 3.4587
Epoch 13/30
547/547 [==============================] - 9s 16ms/step - loss: 146.5592 - reconstruction_loss: 143.0784 - kl_loss: 3.4807
Epoch 14/30
547/547 [==============================] - 9s 16ms/step - loss: 146.2075 - reconstruction_loss: 142.7118 - kl_loss: 3.4957
Epoch 15/30
547/547 [==============================] - 9s 16ms/step - loss: 145.9416 - reconstruction_loss: 142.4310 - kl_loss: 3.5106
Epoch 16/30
547/547 [==============================] - 9s 17ms/step - loss: 145.5281 - reconstruction_loss: 142.0085 - kl_loss: 3.5196
Epoch 17/30
547/547 [==============================] - 9s 16ms/step - loss: 145.2842 - reconstruction_loss: 141.7477 - kl_loss: 3.5365
Epoch 18/30
547/547 [==============================] - 9s 16ms/step - loss: 145.1024 - reconstruction_loss: 141.5528 - kl_loss: 3.5496
Epoch 19/30
547/547 [==============================] - 9s 16ms/step - loss: 144.7374 - reconstruction_loss: 141.1775 - kl_loss: 3.5599
Epoch 20/30
547/547 [==============================] - 9s 16ms/step - loss: 144.5054 - reconstruction_loss: 140.9273 - kl_loss: 3.5781
Epoch 21/30
547/547 [==============================] - 9s 16ms/step - loss: 144.3437 - reconstruction_loss: 140.7661 - kl_loss: 3.5776
Epoch 22/30
547/547 [==============================] - 9s 16ms/step - loss: 144.1328 - reconstruction_loss: 140.5432 - kl_loss: 3.5897
Epoch 23/30
547/547 [==============================] - 9s 16ms/step - loss: 143.9308 - reconstruction_loss: 140.3421 - kl_loss: 3.5887
Epoch 24/30
547/547 [==============================] - 9s 16ms/step - loss: 143.7300 - reconstruction_loss: 140.1331 - kl_loss: 3.5968
Epoch 25/30
547/547 [==============================] - 9s 16ms/step - loss: 143.5860 - reconstruction_loss: 139.9617 - kl_loss: 3.6243
Epoch 26/30
547/547 [==============================] - 9s 16ms/step - loss: 143.4559 - reconstruction_loss: 139.8398 - kl_loss: 3.6162
Epoch 27/30
547/547 [==============================] - 9s 16ms/step - loss: 143.3631 - reconstruction_loss: 139.7232 - kl_loss: 3.6399
Epoch 28/30
547/547 [==============================] - 9s 16ms/step - loss: 143.2122 - reconstruction_loss: 139.5813 - kl_loss: 3.6309
Epoch 29/30
547/547 [==============================] - 9s 16ms/step - loss: 143.0500 - reconstruction_loss: 139.4098 - kl_loss: 3.6402
Epoch 30/30
547/547 [==============================] - 9s 16ms/step - loss: 142.8536 - reconstruction_loss: 139.2033 - kl_loss: 3.6503
<tensorflow.python.keras.callbacks.History at 0x7fe18c2a9518>

# Visualising the decoder output over the 2-D latent space

import matplotlib.pyplot as plt


def plot_latent(n=30, figsize=15, scale=2.0):
    """Show an n x n grid of digits decoded from a lattice of latent points.

    Latent coordinates are spaced evenly over ``[-scale, scale]`` on both
    axes, decoded with the module-level ``vae.decoder`` and tiled into a
    single image. Requires a two-dimensional latent space.

    Args:
        n: number of digits along each side of the grid.
        figsize: width and height of the matplotlib figure, in inches.
        scale: half-width of the latent range covered by the grid.
    """
    digit_size = 28

    grid_x = np.linspace(-scale, scale, n)
    # Reversed so that z[1] increases from the bottom of the image upwards.
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Decode the whole lattice in one predict call instead of n * n calls.
    # Rows of z_grid are ordered row-major: index = i * n + j.
    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
    decoded = vae.decoder.predict(z_grid, verbose=0)

    # (n*n, h, w, 1) -> (i, j, h, w) -> (i, h, j, w) -> (n*h, n*w), so tile
    # (i, j) occupies rows i*h:(i+1)*h and columns j*w:(j+1)*w.
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    figure = tiles.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.figure(figsize=(figsize, figsize))
    # One tick at the centre of every tile: exactly n positions, matching
    # the n labels (a mismatch makes matplotlib raise ValueError).
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent()