Exploring Data with Generative Clustering

By Max Candocia

January 09, 2018

When working with data, one often wants to cluster it into groups based on its features. Clustering involves assigning a group label to each element of a data set (usually a row, but possibly the factor levels of one or more columns), typically based on either a distance from a group "centroid" or a probability distribution describing each group. There are several reasons to do this:

  1. The features are high-dimensional and you would like to reduce the dimensionality
  2. Possibly because the features are high-dimensional, you would like to perform another task (such as visualization, regression, or even more clustering with different features) on each individual cluster.
  3. You want to use the clustering criteria to determine what makes different groups the most different
  4. You want to know what rows/objects are most similar to each other

Using keywords in documents as an example of each of these:

  1. You have a bunch of keywords describing documents, and you want to use a document's "topic" as a feature in a regression (e.g., how popular is a document based on topic, word count, publication date, etc.?). The topic can only be inferred using clustering methodology in cases without any pre-assigned labels.
  2. You have a bunch of keywords describing documents, and you would like to create a different model for each topic, since you believe the features are too dissimilar in underlying structure across different topics.
  3. You have a bunch of keywords describing documents, and you want to know what keywords contribute to a document being placed in a particular topic.
  4. You have a bunch of keywords describing documents, and you want to recommend similar documents to a user viewing one of them.

Any of the above can also apply to datasets with numeric variables, which is what I will focus on below, using a clustering method known as the expectation-maximization (EM) algorithm. This algorithm essentially involves these steps:

  1. Select a number of clusters, and randomly assign a cluster to each row of the data
  2. For each cluster, estimate a distribution/centroid based on pre-selected features
  3. For each row of data, reassign it to the cluster with the highest probability (density) or the smallest distance to its centroid
  4. Repeat steps 2-3 until the cluster assignments are stable or a predetermined maximum number of iterations has been reached
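To make the four steps concrete, here is a minimal base-R sketch of this hard-assignment flavor of EM. It is not the `generative_model()` function used below; `hard_em_diag()` is a hypothetical name, and it assumes each cluster is an axis-aligned normal distribution (one mean and one standard deviation per variable, no correlations).

```r
# Hypothetical sketch of hard-assignment EM with axis-aligned Gaussian clusters.
# Not the article's generative_model(); base R only.
hard_em_diag <- function(x, n_groups, n_iter = 30, min_sd = 1e-3) {
  x <- as.matrix(x)
  n <- nrow(x)
  # Step 1: randomly assign a cluster to each row
  groups <- sample(rep_len(seq_len(n_groups), n))
  for (iter in seq_len(n_iter)) {
    # Step 2: estimate each cluster's distribution (mean and sd per variable),
    # then compute the log density of every row under that cluster
    log_dens <- sapply(seq_len(n_groups), function(g) {
      members <- x[groups == g, , drop = FALSE]
      if (nrow(members) < 2) return(rep(-Inf, n))  # empty or degenerate cluster
      mu <- colMeans(members)
      sds <- pmax(apply(members, 2, sd), min_sd)
      rowSums(sapply(seq_len(ncol(x)), function(j)
        dnorm(x[, j], mu[j], sds[j], log = TRUE)))
    })
    # Step 3: reassign each row to the cluster with the highest density
    new_groups <- max.col(log_dens, ties.method = "first")
    # Step 4: stop early once the assignments no longer change
    if (all(new_groups == groups)) break
    groups <- new_groups
  }
  groups
}

# Usage: two clusters on petal length and width
set.seed(2018)
petal_groups <- hard_em_diag(iris[, 3:4], n_groups = 2)
table(iris$Species, petal_groups)
```

Working on the log scale and summing per-variable log densities keeps the computation numerically stable; a full-covariance version would replace the `dnorm()` product with a multivariate normal density.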

Here is an example of the algorithm in action on two variables from the classic Iris dataset:

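# generative_model() and create_history_frame() are defined in the author's
# generative_model.r (see the GitHub repository linked at the end)
source('generative_model.r')
library(ggplot2)
library(gganimate)  # legacy gganimate API: gganimate() and the frame aesthetic
set.seed(2018)
# Two clusters on petal length and width, with unconstrained covariances
iris.model_free_simple = generative_model(iris, variables=names(iris)[3:4], n_groups=2,
                                   covariance_constraint='free',
                                   record_history=TRUE,
                                   include_data_copy=TRUE,
                                   n_iter=30,
                                   prior_smoothing_constant = 0.1)
# Cluster assignments recorded at each iteration (record_history=TRUE)
hf = create_history_frame(iris.model_free_simple)
p = ggplot(hf, aes(x=Petal.Length, y=Petal.Width, color=factor(group), frame=factor(iteration))) +
  geom_point() + ggtitle('Clustering of Iris data: Iteration #') +
  stat_ellipse(geom='polygon', aes(fill=factor(group)), alpha=0.5)
# One animation frame per iteration, written to a local GIF file
gganimate(p, title_frame=TRUE, filename='simple_example.gif', saver='gif')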
Animation: clustering of Iris petal length and width, iteration by iteration

As you can see, the initial assignments are completely random, but the algorithm is able to detect two separate clusters, describing each as a multivariate normal distribution. This type of distribution is good at detecting clusters centered around ellipsoid shapes at arbitrary angles.

One should also note that some combinations of data and number of clusters are not stable: the result depends on the initial cluster assignments. The more unstable a combination is, the less likely it is that both of the following are true:

  1. The number of clusters is appropriate for the data and algorithm
  2. You are using an appropriate algorithm for clustering the data (if it can even be clustered)
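One rough way to check stability is to cluster the same data from several random starts and measure how much the resulting labelings agree. The sketch below assumes a hand-written Rand index (`rand_index()` is a hypothetical helper) and uses base R's `kmeans()` as a stand-in for the article's `generative_model()`.

```r
# Rand index: the share of row pairs on which two clusterings agree
# (both "same cluster" or both "different clusters"). It ignores how the
# labels are numbered, so relabelled clusters still count as a match.
rand_index <- function(a, b) {
  same_a <- outer(a, a, "==")
  same_b <- outer(b, b, "==")
  n <- length(a)
  # subtract the n diagonal (self) pairs, which always agree
  (sum(same_a == same_b) - n) / (n * (n - 1))
}

# Cluster the same data from several random starts
# (kmeans() is a stand-in for generative_model() here)
x <- iris[, 1:4]
labels_by_seed <- lapply(1:5, function(s) {
  set.seed(s)
  kmeans(x, centers = 3)$cluster
})

# Agreement for every pair of runs; values well below 1 signal instability
pairwise_ri <- combn(length(labels_by_seed), 2, function(ij)
  rand_index(labels_by_seed[[ij[1]]], labels_by_seed[[ij[2]]]))
summary(pairwise_ri)
```

The `outer()` comparison builds an n-by-n matrix, which is fine for a data set the size of Iris but would need a contingency-table formulation for large data.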

Clustering to Detect Underlying Structure

Using all four columns of the Iris data set (sepal and petal width and length), one can check how well the clusters correspond to the three iris species: Setosa, Versicolor, and Virginica.

First, we can use the k-means algorithm. My implementation is very similar to the one above, except that all dimensions and all clusters share the same variance, and all correlations are removed.
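For readers who want to try this without the article's code, base R's `kmeans()` fits the same kind of equal-variance, uncorrelated clusters. This is a stand-in, not the author's implementation, so its cluster numbering and exact counts may differ from the table below.

```r
# Base R's kmeans() on all four measurements (a stand-in for the article's
# own k-means implementation). Species is used only afterwards, to compare.
set.seed(2018)
km <- kmeans(iris[, 1:4], centers = 3, nstart = 10)

# Cross-tabulate species against cluster; cluster numbers are arbitrary labels
cluster_vs_species <- table(species = iris$Species, cluster = km$cluster)
print(cluster_vs_species)
```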

Figure: k-means clustering example. The clustering algorithm can perfectly isolate the Setosa species, and does a decent job of separating Versicolor and Virginica.

Species       Cluster 1   Cluster 2   Cluster 3
setosa                0          50           0
versicolor           46           0           4
virginica            11           0          39

The clusters seem to do a fairly good job of isolating the different species, even though the species was never used at all to create the clusters.

If we relax our constraints, the cluster shapes can adapt to the different spread of each variable and describe the data more usefully.

Figure: model clustering example

Species       Cluster 1   Cluster 2   Cluster 3
setosa                0          50           0
versicolor           48           0           2
virginica             4           0          46

The clustering almost perfectly isolates each of the species, and this is a completely unsupervised model! This model does not assume that every variable has the same variance, which is a reasonable relaxation. It only estimates one variance per variable (shared by all clusters), though, so it does not add much complexity to the model.
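The "one variance per variable, shared by all clusters" parameter can be estimated from the current cluster assignments as a pooled within-cluster variance. The helper below is a hypothetical sketch of that step; the n - k denominator is my assumption, and the article's code may normalize differently.

```r
# Hypothetical helper: pooled within-cluster variance of each variable,
# i.e. one variance per variable shared by every cluster (n - k denominator).
pooled_variances <- function(x, groups) {
  x <- as.matrix(x)
  # replace each value by its cluster's mean, column by column
  cluster_means <- apply(x, 2, function(col) ave(col, groups))
  residuals <- x - cluster_means
  colSums(residuals^2) / (nrow(x) - length(unique(groups)))
}

# Usage with k-means labels as the cluster assignments
set.seed(2018)
groups <- kmeans(iris[, 1:4], centers = 3, nstart = 10)$cluster
round(pooled_variances(iris[, 1:4], groups), 3)
```

If the four pooled variances differ substantially, forcing them to be equal (as k-means does) distorts distances along the low-variance variables.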


Figure: free-diagonal model clustering example

Species       Cluster 1   Cluster 2   Cluster 3
setosa               50           0           0
versicolor            0          15          35
virginica             0          48           2

Here we allow the variances to vary across variables and clusters. This does have the consequence of some clusters partially engulfing other clusters, though. While this may be appropriate in some situations, there's no reason to believe that the distributions of variables should have vastly different shapes between clusters.

The fact that the clusters correspond less closely to the underlying species suggests that this model is not as good at describing the underlying structure of the data as the previous model.


Figure: unconstrained clustering example

Species       Cluster 1   Cluster 2   Cluster 3
setosa                0          50           0
versicolor           13          22          15
virginica            18           6          26

Interestingly, when one adds more degrees of freedom (in this case, correlations), the algorithm becomes more unstable. You can see this from the ellipses stretching across clusters: along a direction of high correlation, distance is penalized much less. The Setosa cluster also absorbs 28 Versicolor and Virginica points, making this the worst-performing of the four models.
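The claim that high correlation penalizes distance less can be checked with the Mahalanobis distance, which is what a multivariate normal density is based on. The covariance matrix below is an illustrative assumption, not one estimated from the Iris data.

```r
# Illustrative covariance with strong positive correlation (not fit to Iris)
sigma <- matrix(c(1.0, 0.9,
                  0.9, 1.0), nrow = 2)

# Two points at the same Euclidean distance, sqrt(2), from the center
points <- rbind(along  = c(1,  1),   # along the correlated direction
                across = c(1, -1))   # perpendicular to it

# Squared Mahalanobis distance from the center under sigma
d2 <- mahalanobis(points, center = c(0, 0), cov = sigma)
print(d2)
```

The point along the correlated direction gets a squared distance of about 1.05, while the perpendicular point gets exactly 20. Since the normal density falls off as exp(-d2 / 2), a strongly correlated cluster can claim points far out along its long axis, which is how the ellipses end up reaching across other clusters.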

Evaluating Cluster Quality

How does one measure the quality of the clusters above? The simplest way to evaluate the quality of clusters like these that have underlying probabilities attached to them is to calculate their deviance.

The deviance is -2 * log(model_likelihood), where the model likelihood is the product of the probabilities (for categorical variables) or probability densities (for numeric variables) of each row under its assigned cluster. The higher the deviance, the worse the fit. Below are the deviances for the different models shown above:
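Written out, the density version that `generative.deviance()` computes is the following, where f_k (notation introduced here) is the fitted density of cluster k, x_i is row i of n rows, and ε = 1e-14 is the floor the code applies before taking logs. Taking the log turns the product of per-row densities into a sum.

```latex
D = -2 \log L = -2 \sum_{i=1}^{n} \log \max\Big(\varepsilon,\ \max_{k} f_k(x_i)\Big)
```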

# Deviance of a fitted clustering model: -2 times the summed log density
# (or probability) of each row under its most likely cluster
generative.deviance <- function(model, use_density=TRUE, epsilon=1e-14){
  if (use_density) {
    # floor densities at epsilon so that log() never returns -Inf
    return(-2 * sum(log(pmax(epsilon, apply(model$prob_densities, 1, max)))))
  } else {
    return(-2 * sum(log(apply(model$probs, 1, max))))
  }
}
# k-means clustering
generative.deviance(iris.model_kmeans)
## [1] 823.2612
# clustering with different variances per variable
generative.deviance(iris.model_constant)
## [1] 728.1573
# clustering with different variances per variable and cluster
generative.deviance(iris.model_diagonal)
## [1] 625.2206
# clustering with no restrictions on covariances
generative.deviance(iris.model_free)
## [1] 619.4432

The third and fourth models have the lowest deviances, which is expected, since they have the most free parameters with which to fit the data.

Fixing Poorly Fit Clusters

One way to fix poorly fit clusters is to keep trying different random starting values until you get a desirable result. Since deviance is quick and easy to compute, we can try, say, 51 different random seeds for the most complicated model and see how much the results improve.

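# Refit the unconstrained model from 51 random starts (seeds 2000-2050)
# and keep the fit with the lowest deviance
best_deviance = Inf
best_seed = -1
best.model = NULL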

for (seed in 2000:2050){
  set.seed(seed)
  gen.model = generative_model(iris[,1:4],
                               n_groups=3,
                               covariance_constraint='free',
                               record_history=TRUE)
  dev = generative.deviance(gen.model)
  if (dev < best_deviance){
    best_deviance = dev
    best_seed = seed
    best.model = gen.model
  }
}
print(best_deviance)
## [1] 377.0754
print(best_seed)
## [1] 2004
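Base R's `kmeans()` has the same "restart and keep the best" idea built in through its `nstart` argument: it runs several random starts and returns the one with the lowest total within-cluster sum of squares, its analogue of deviance. A small comparison, using `kmeans()` only as an illustration rather than the model above:

```r
# One random start versus the best of 51 random starts
set.seed(2000)
km_single <- kmeans(iris[, 1:4], centers = 3, nstart = 1)
set.seed(2000)
km_multi <- kmeans(iris[, 1:4], centers = 3, nstart = 51)

# Lower total within-cluster sum of squares means a tighter fit
c(single_start = km_single$tot.withinss,
  best_of_51   = km_multi$tot.withinss)
```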

The best deviance from this seed search is about 377, roughly 39% lower than the previous best of 619.4. Let's see how this model was created!

Figure: 'best' clustering example

Species       Cluster 1   Cluster 2   Cluster 3
setosa                0           0          50
versicolor           49           1           0
virginica            16          34           0

It appears as if the "best" model did not really describe the structure very well. It mostly isolated Setosa, but almost all of Virginica and Versicolor are together, with a remaining sliver of points being in cluster 3, which is a definite example of overfitting in unsupervised learning. There is no reason to believe that such a small, elongated cluster should have any meaning in this context.

Moral of the story: simpler structures in data tend to be more meaningful than complex ones.

GitHub Code

All of the code for generating this article and its images can be found at https://github.com/mcandocia/generative_clustering.
