Generative models - A gentle introduction

Machine learning models can be broadly divided into discriminative and generative models.

Broad classification of Machine Learning Models

To understand the idea behind both kinds of model, let's take a toy dataset of just 4 samples: one male aged more than 30 years, one male aged less than 30, and two females aged more than 30. Let's say we fit a toy model to this data, where the input x to the model is the gender and the output y is the classification of whether the person is less than or more than 30 years of age.

Let's say we have trained a model on this dataset. During prediction, whenever we get a female as input, the model will predict that the person is aged more than 30. To see why, let's look at our training data and convert the counts into probabilities.

Suppose I tell you the person is a male. There are only two males in the training dataset, so the chance of the output being less than 30 is 1 in 2, or 1/2. Similarly, the chance of the output being more than 30 is also 1 in 2, simply because I have already told you the person is a male. Arguing in the same way, if I tell you the person is a female, there are two females in our dataset and both of them are aged more than 30. So the probability of the output being more than 30 is 1, which is why the model always predicts more than 30 for a female.

What we have just computed is called the conditional probability distribution p(y | x). Building models by conditioning on the input in this way is called discriminative modelling. Because a discriminative model only learns how the output depends on the input, training it amounts to drawing a boundary between the different classes. Each training sample therefore needs a label, which means that discriminative models are almost always supervised models.
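The conditional probabilities worked out above can be reproduced by counting. The following is a minimal sketch, not a real model: the tuple encoding of the dataset and the function name `conditional_y_given_x` are assumptions made for illustration.

```python
from collections import Counter
from fractions import Fraction

# Toy dataset from the text, encoded as (gender x, age group y) pairs.
# The string labels are an illustrative choice.
DATA = [
    ("male", ">30"),
    ("male", "<30"),
    ("female", ">30"),
    ("female", ">30"),
]


def conditional_y_given_x(data):
    """Estimate p(y | x) by counting within each value of x."""
    pair_counts = Counter(data)              # how often each (x, y) occurs
    x_counts = Counter(x for x, _ in data)   # how often each x occurs
    return {
        (x, y): Fraction(n, x_counts[x])
        for (x, y), n in pair_counts.items()
    }


p_y_given_x = conditional_y_given_x(DATA)
for (x, y), p in sorted(p_y_given_x.items()):
    print(f"p(y={y} | x={x}) = {p}")
```

Pairs that never occur in the data, such as a female aged less than 30, are simply absent from the result, which corresponds to a probability of 0.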

Now let's look at the same dataset from a different perspective. Let's consider all 4 samples together and convert them into probabilities. There are a total of 4 people, with only one male aged more than 30, so that probability is 1 in 4. Likewise, the probability of a male aged less than 30 is 1/4. There are 2 females out of 4 aged more than 30, so in terms of probabilities it's 1/2. Finally, there are no females aged less than 30, so that probability is 0.

This method of calculating the probabilities by considering all the samples jointly leads to the joint distribution p(x, y), and a model that learns this joint distribution is called a generative model. Unlike discriminative models, generative models learn the probability of the entire dataset jointly, without drawing any boundaries. And so, generative models are well suited to unsupervised problems.
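The joint probabilities can be computed the same way, this time dividing every count by the total number of samples. This is a small sketch under the same assumed encoding of the toy dataset; `joint_distribution` is a hypothetical helper name.

```python
from collections import Counter
from fractions import Fraction
from itertools import product

# Toy dataset from the text, encoded as (gender x, age group y) pairs.
DATA = [
    ("male", ">30"),
    ("male", "<30"),
    ("female", ">30"),
    ("female", ">30"),
]


def joint_distribution(data):
    """Estimate p(x, y) by dividing each pair count by the dataset size."""
    counts = Counter(data)
    total = len(data)
    xs = sorted({x for x, _ in data})
    ys = sorted({y for _, y in data})
    # Enumerate every (x, y) combination so unseen pairs get probability 0.
    return {(x, y): Fraction(counts[(x, y)], total) for x, y in product(xs, ys)}


p_xy = joint_distribution(DATA)
for (x, y), p in p_xy.items():
    print(f"p(x={x}, y={y}) = {p}")
```

Unlike the conditional case, all four entries here sum to 1, because they describe the whole dataset at once.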

Speciality of generative models

Imagine scaling up this toy problem hundreds of times, with a training dataset containing thousands of example images of cats and dogs. With a generative model trained this way, we can draw a sample from the distribution learnt by the model and end up with an image at the output, say of a cat. If we draw another sample from a different part of the distribution, we could end up with an image of a dog. But note that we don't have any control over what the output image will be. Still, generative models built this way have huge potential, and state-of-the-art models such as StyleGAN and Glow are very good examples of models built this way.
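Drawing samples without any control over the outcome can be illustrated with the toy joint distribution standing in for a learnt image distribution. This is only a sketch: the `P_XY` table, the `sample_joint` helper and the fixed seed are assumptions for illustration, and a real image model samples in a very different way.

```python
import random

# Toy joint distribution p(x, y) from the text, written out by hand.
P_XY = {
    ("male", ">30"): 0.25,
    ("male", "<30"): 0.25,
    ("female", ">30"): 0.5,
    ("female", "<30"): 0.0,
}


def sample_joint(p_xy, n, seed=0):
    """Draw n (x, y) pairs according to their joint probabilities."""
    rng = random.Random(seed)  # seeded only to make the sketch repeatable
    outcomes = list(p_xy)
    weights = [p_xy[o] for o in outcomes]
    return rng.choices(outcomes, weights=weights, k=n)


samples = sample_joint(P_XY, 5)
for x, y in samples:
    print(x, y)
```

Each call returns whatever the distribution happens to produce. There is no argument for asking for a particular kind of sample, which is the limitation described above.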

Conditional Generative Models

To understand how we can control the generated output, let's go back to our toy example with the joint distribution of the generative model, p(x, y). If we know the labels y of our training data, we can calculate the distribution of the labels, p(y). We can then simply divide the joint distribution by the distribution of the labels, and we arrive at the conditional probability distribution p(x | y) = p(x, y) / p(y). With this solution, we can input the label cat and ask the model to output x, which will be an image of a cat. Or we can input the label dog and get an image of a dog as the output. This solution gives rise to what are called conditional generative models, and it is a fundamental building block of generative models such as DALL-E that take text as input and generate an image at the output.
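The division described above can be sketched on the toy dataset, where x is the gender and y is the age group acting as the label. The helper names `marginal_y`, `conditional_x_given_y` and `sample_x` are hypothetical, and sampling from a small lookup table is only a stand-in for what an image model would do.

```python
import random
from fractions import Fraction

# Toy joint distribution p(x, y) from the text.
P_XY = {
    ("male", ">30"): Fraction(1, 4),
    ("male", "<30"): Fraction(1, 4),
    ("female", ">30"): Fraction(1, 2),
    ("female", "<30"): Fraction(0),
}


def marginal_y(p_xy):
    """Compute p(y) by summing the joint distribution over x."""
    p_y = {}
    for (_, y), p in p_xy.items():
        p_y[y] = p_y.get(y, Fraction(0)) + p
    return p_y


def conditional_x_given_y(p_xy):
    """Compute p(x | y) = p(x, y) / p(y) for every label with p(y) > 0."""
    p_y = marginal_y(p_xy)
    return {(x, y): p / p_y[y] for (x, y), p in p_xy.items() if p_y[y] > 0}


def sample_x(p_x_given_y, y, rng):
    """Draw one x for the requested label y."""
    xs = [x for (x, label) in p_x_given_y if label == y]
    weights = [float(p_x_given_y[(x, y)]) for x in xs]
    return rng.choices(xs, weights=weights, k=1)[0]


p_x_given_y = conditional_x_given_y(P_XY)
rng = random.Random(0)
print(p_x_given_y[("female", ">30")])   # probability of female given age > 30
print(sample_x(p_x_given_y, "<30", rng))
```

Passing the label to `sample_x` is what provides the control: in the toy data, asking for the label "<30" can only ever return a male, because no female under 30 appears in the dataset.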

This post is just a primer for a series of posts about generative models.