Clustering with K-Means in Spark and MLlib

MLlib provides a parallel implementation of k-means that uses the k-means|| initialization scheme (a parallelized variant of k-means++), so it runs efficiently on Spark.  Clustering is an unsupervised machine learning technique that helps us discover natural patterns in data.

What is k-means?  It’s about the simplest clustering algorithm out there: it starts with a given number (k) of centroids placed at random, assigns each point to its nearest centroid, recomputes each centroid as the mean of the points assigned to it, and repeats until the centroids converge.  It works great for data which is clusterable by circles/spheres (actually hyperspheres).  For data which has more convoluted patterns, such as rings and other shapes, we can use hierarchical clustering instead.
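To make the assign-then-recompute loop concrete, here is a tiny sketch of a single k-means iteration in plain Scala (not MLlib, and using Scala's own Vector[Double] rather than MLlib vectors). It's purely illustrative; a real implementation would also handle clusters that end up empty.

```scala
// Euclidean distance between two points.
def dist(a: Vector[Double], b: Vector[Double]): Double =
  math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

// One iteration: assign each point to its nearest centroid,
// then recompute each centroid as the mean of its assigned points.
def kmeansStep(points: Seq[Vector[Double]],
               centroids: Seq[Vector[Double]]): Seq[Vector[Double]] =
  points
    .groupBy(p => centroids.minBy(c => dist(p, c)))            // assignment step
    .values
    .map(cluster => cluster.transpose.map(_.sum / cluster.size).toVector) // update step
    .toSeq
```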

The other limitation is that we have to know k in advance to make the algorithm work.  Sometimes we know what we’re looking for, but usually we don’t.  That means we often end up running the clustering many times, measuring at each value of k how well the clusters fit the data.

If we have a dataset we want to cluster, the first step is to convert it to one of MLlib’s vector classes.   MLlib offers two: Vectors.dense and Vectors.sparse.  The latter is very good for one-hot encodings (is_red, is_blue, is_green, etc.), and especially for encoding text vectors, such as TF-IDF or word2vec.  I’ll talk in another post about how to vectorize text.
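For illustration, here is roughly what the two look like; the sparse example assumes a hypothetical three-category one-hot encoding (is_red, is_blue, is_green):

```scala
import org.apache.spark.mllib.linalg.Vectors

// A dense vector stores every value explicitly.
val dense = Vectors.dense(21.0, 6.0, 160.0, 110.0)

// A sparse vector stores only the non-zero entries: (size, indices, values).
// Handy for one-hot encodings where most positions are zero.
val sparse = Vectors.sparse(3, Array(1), Array(1.0)) // is_red=0, is_blue=1, is_green=0
```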

Instead, let’s just use Vectors.dense with a dataset near and dear to R users: mtcars.  It’s one of R’s standard example datasets, giving a few statistics on a handful of car models.  We can extract it from R with write.csv and use the file as mtcars.csv, removing the header row for simplicity.  Of course, it’d be silly to use Spark for such a tiny dataset (we could easily just use R), but it serves the purpose of an example.
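A minimal sketch of loading that file in the Spark shell might look like the following. It assumes sc is the SparkContext, mtcars.csv sits in the working directory, the first field of each row is the car name, and the remaining eleven fields are the numeric columns (mpg, cyl, disp, hp, drat, wt, qsec, vs, am, gear, carb).

```scala
import org.apache.spark.mllib.linalg.Vectors

// Header row already removed; write.csv puts quotes around the car names,
// so strip those before splitting.
val raw = sc.textFile("mtcars.csv")

val vectors = raw.map { line =>
  val fields = line.replaceAll("\"", "").split(",")
  // Drop the car name (field 0) and keep only the numeric columns.
  Vectors.dense(fields.tail.map(_.toDouble))
}
vectors.cache()
```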

Great. Now we have some vectors.   We had to drop the name associated with each car and so our vectors are nameless — more on that later.

Now we need to make a KMeansModel object.   This may seem strange at first glance, since in R and Mahout there’s no model associated with k-means: there’s no training involved in an unsupervised ML algorithm.  Probably for the sake of consistency, MLlib treats k-means as a model that has to be “trained” with data and can then be applied to new data using predict(), as if it were performing classification. While odd, this is actually a bonus, because it easily allows us to use our clusters as a classification model for unseen data.
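Building on the vectors RDD above, the training step is a one-liner (the k and iteration values here are just illustrative):

```scala
import org.apache.spark.mllib.clustering.KMeans

val numClusters = 2
val numIterations = 20

// "Train" the model; the result is a KMeansModel holding the k cluster centers.
val clusters = KMeans.train(vectors, numClusters, numIterations)
```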

So clusters in this case is the KMeansModel object.  We chose a k value of 2, which probably isn’t going to give good results with this dataset.  How do we check that? We can use computeCost().
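Something along these lines, continuing with the clusters model and vectors RDD from above:

```scala
// WSSSE: the sum of squared distances from each point to its nearest center.
val wssse = clusters.computeCost(vectors)
println(s"Within Set Sum of Squared Errors = $wssse")
```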

The Spark documentation calls this cost the WSSSE (Within Set Sum of Squared Errors).   Typically it decreases as k gets higher, but higher values of k may not produce very useful clusters (lots of clusters-of-one, for instance).

Intuitively, we should set k just before the point where the law of diminishing returns sets in, sometimes called the “elbow method.” But we should also watch for where we start getting lots of small, meaningless clusters.
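One rough way to eyeball that elbow is simply to train a model for a range of k values and print the cost for each; the range here is arbitrary:

```scala
// Sketch of the "elbow" search: watch where the cost stops improving quickly.
(2 to 10).foreach { k =>
  val model = KMeans.train(vectors, k, numIterations)
  println(s"k = $k, WSSSE = ${model.computeCost(vectors)}")
}
```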

So now we have a KMeansModel trained with our value of k. What does that give us? It assigns a number to each cluster (for k=2, just 0 and 1), but remember that we dropped the name for each vector. So we know which vectors are in each cluster, but how do we relate that back to the original data?  Having done this exercise in Mahout, I went looking for the NamedVector class, which unfortunately doesn’t exist in Spark. The Spark team apparently doesn’t feel one is needed.

In Spark, the right way to do this is to join the vectors back to the original data. To do that, we need to create pairs of names and vectors.
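A sketch of that pairing, reusing the raw RDD and clusters model defined earlier:

```scala
// Keep the car name alongside its feature vector so we can see
// which car ended up in which cluster.
val namedVectors = raw.map { line =>
  val fields = line.replaceAll("\"", "").split(",")
  (fields(0), Vectors.dense(fields.tail.map(_.toDouble)))
}

// Ask the model which cluster each car belongs to.
val assignments = namedVectors.map { case (name, vec) =>
  (name, clusters.predict(vec))
}

assignments.collect().foreach { case (name, cluster) =>
  println(s"$name -> cluster $cluster")
}
```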

So that gives us our clustering results. As we said before, we can call predict() on any new data we might have, to see which cluster it would correspond to.
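For example, something like this, where the feature values are made up but follow the same column order as our mtcars vectors:

```scala
// Classify a previously unseen car by asking which centroid it is closest to.
val newCar = Vectors.dense(26.0, 4.0, 120.3, 91.0, 4.43, 2.14, 16.7, 0.0, 1.0, 5.0, 2.0)
val assignedCluster = clusters.predict(newCar)
println(s"New car falls into cluster $assignedCluster")
```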

The new data doesn’t actually change the model, however. That’s frozen in time forever until we train a new one. There is, however, another class called StreamingKMeans, which will actually adjust the clusters as new data arrives, so we can use it in a streaming fashion. We’ll talk about that another time.
