Machine Learning – Best Frameworks for Supervised Clustering

Tags: clustering, neural networks, supervised learning

I have a task where I need to take "data points" that consist of collections of items. Each item needs to be categorised according to predefined categories. That's the easy part: my solution is to train a deep neural network with a cross-entropy loss. By the way, the reason I don't classify each item separately is that the items acquire their meaning only when they come together as a set.
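For the "easy part", the loss itself is ordinary softmax cross-entropy applied to every item and averaged over the set. What makes it a set problem is that the network sees the whole collection when it produces each item's scores. A minimal NumPy sketch of the loss alone, assuming the network already outputs an `(n_items, n_categories)` array of logits for one set (the function name `set_cross_entropy` is hypothetical):

```python
import numpy as np

def set_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean softmax cross-entropy over the items of one set.

    logits: (n_items, n_categories) raw scores from the network. They are
            assumed to be computed jointly, so each item's scores can
            depend on the rest of the set.
    labels: (n_items,) integer category index for each item.
    """
    # Subtract each row's max before exponentiating, for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Log-probability assigned to the true category of each item.
    picked = log_probs[np.arange(len(labels)), labels]
    return float(-picked.mean())

# Uniform scores over two categories give a loss of log(2) per item.
print(round(set_cross_entropy(np.zeros((3, 2)), np.array([0, 1, 0])), 4))  # 0.6931
```

Because the loss is per item, any set-level reasoning has to live in the architecture rather than in the loss.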

The hard part is that each of these items also has a cluster label. Each cluster can contain items of only one category, and there can be any number of clusters. Unsupervised clustering methods (applied after the neural network does the categorisation) work fairly well, but not well enough for my needs. I'd like to:
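One way to respect the "one category per cluster" constraint with an unsupervised method is to cluster separately within each predicted category. The sketch below does this with single-linkage clustering implemented via union-find. The features (for example bounding-box centres), the distance `threshold`, and the function name are hypothetical choices, not the method used in the question:

```python
import numpy as np

def cluster_within_categories(features, categories, threshold):
    """Single-linkage clustering restricted to items of the same category.

    features:   (n, d) array, e.g. bounding-box centres (hypothetical choice).
    categories: length-n sequence of predicted category ids.
    threshold:  hypothetical distance below which two items are linked.
    Returns a list of cluster ids; no cluster ever mixes categories.
    """
    features = np.asarray(features, dtype=float)
    n = len(categories)
    parent = list(range(n))

    def find(i):
        # Path halving keeps the union-find trees shallow.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            same_category = categories[i] == categories[j]
            close = np.linalg.norm(features[i] - features[j]) < threshold
            if same_category and close:
                parent[find(i)] = find(j)

    # Relabel roots as consecutive ids in order of first appearance.
    ids = {}
    return [ids.setdefault(find(i), len(ids)) for i in range(n)]

# The dog sits right next to the first cat but is never merged with it.
print(cluster_within_categories(
    [[0, 0], [1, 0], [10, 0], [0.5, 0]], ["cat", "cat", "cat", "dog"], 2.0))
# [0, 0, 1, 2]
```

A fixed threshold like this uses none of the ground-truth cluster labels, which is exactly the limitation points A and B below are about.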

A. Make use of the fact that I have the ground truth labelling for these clusters

B. Leverage my deep neural network, because a lot of the "reasoning" required to solve the classification task should also be useful for the clustering task.

Answers which address at least one of those are useful to me. Thanks!

EDIT

A (hopefully minimal) example:

The task: I have images with any number of cats and dogs in them. The dogs and cats tend to hang out in groups. I already have unlabelled bounding boxes around each animal. I need to:

  1. Categorize each animal as either a cat or a dog (the "easy part")
  2. Cluster the cats into groups, and the dogs into groups. For example, there may be 3 cats hanging out by a garbage bin, which I would call one cluster, and 4 other cats playing with a ball of yarn, which belong to another cluster. To signify a cluster, I give each cat (or dog) within it the same label. It doesn't matter what the label is.

The learning framework: I have a training dataset of images. For each image I have:

  1. Labels for each bounding box (cat or dog). This is just a standard classification task: "the easy part".
  2. Labels for each cluster. These are, of course, permutation invariant: it doesn't matter if I swap every x label with a y label within an image. This is "the hard part", the bit I'm asking for help with. Also notice that a deep neural network that is able to solve 1 can probably be reused for solving this problem.
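Because the cluster labels are permutation invariant, two labelings describe the same grouping whenever one can be turned into the other by renaming labels. A small sketch that makes this concrete by renaming clusters in order of first appearance (the helper `canonical_labels` is hypothetical):

```python
def canonical_labels(cluster_labels):
    """Relabel clusters by order of first appearance.

    Two labelings that differ only by renaming clusters (e.g. swapping
    every "x" with "y") map to the same canonical list.
    """
    mapping = {}
    return [mapping.setdefault(label, len(mapping)) for label in cluster_labels]

a = ["x", "x", "y", "z", "y"]
b = ["y", "y", "x", "z", "x"]  # same grouping, x and y swapped
print(canonical_labels(a) == canonical_labels(b))  # True
```

This is also why a plain per-item cross-entropy on raw cluster ids is a poor fit: it would penalise a prediction that groups the items correctly but happens to use different names.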

As an aside, my minimal example is missing a pretty important (but not critical to this question) detail. Somehow, when you look at a cat on its own, it's hard to tell that it's a cat; you need to look at the cats as an ensemble to know what they are. I didn't know how to work this into my example, but the detail is important because it explains why I think 2 should be learned together with 1. A network that can learn 1 probably contains a lot of the "reasoning" required to learn 2. This explains my motivation for B above.

Best Answer

My solution was to use a graph to represent relationships between objects in the scene. Every object is a node, and two nodes are connected if they belong to the same cluster. As a matrix, this is an NxN matrix with 1s where two objects are connected and 0s otherwise, so it is symmetric and every diagonal entry is 1. I used a graph neural network, and its predicted edges were regressed to the ground-truth connectivity matrix with a binary cross-entropy loss. It worked very well.
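The target and loss described in this answer can be sketched in NumPy as follows. The graph neural network itself is not shown: `pred_probs` stands in for whatever NxN matrix of edge probabilities the network outputs. Decoding clusters by thresholding the symmetrised edges at an assumed 0.5 and taking connected components is one plausible inference step, not something the answer specifies, and all function names are hypothetical:

```python
import numpy as np

def connectivity_target(cluster_labels):
    """N x N target: 1 where items i and j share a cluster, else 0.

    Symmetric with ones on the diagonal, and unchanged by any renaming
    of the cluster labels, so it is permutation invariant.
    """
    labels = np.asarray(cluster_labels)
    return (labels[:, None] == labels[None, :]).astype(float)

def edge_bce(pred_probs, target, eps=1e-7):
    """Mean binary cross-entropy between predicted edge probabilities
    (e.g. sigmoid outputs of a graph network) and the target matrix."""
    p = np.clip(pred_probs, eps, 1.0 - eps)
    return float(-np.mean(target * np.log(p) + (1 - target) * np.log(1 - p)))

def decode_clusters(pred_probs, threshold=0.5):
    """Recover cluster ids by thresholding edges and taking connected
    components (an assumed decoding step, not part of the answer)."""
    n = pred_probs.shape[0]
    # Average with the transpose so the edge decision is symmetric.
    adj = (pred_probs + pred_probs.T) / 2 >= threshold
    labels = [-1] * n
    next_id = 0
    for start in range(n):
        if labels[start] != -1:
            continue
        # Depth-first search over edges above the threshold.
        stack = [start]
        labels[start] = next_id
        while stack:
            i = stack.pop()
            for j in np.flatnonzero(adj[i]):
                if labels[j] == -1:
                    labels[j] = next_id
                    stack.append(j)
        next_id += 1
    return labels

# Two items share cluster "x"; the third is alone in "y".
target = connectivity_target(["x", "x", "y"])
print(decode_clusters(target))  # [0, 0, 1]
```

In training, `edge_bce` would be minimised against `connectivity_target` for each image; at inference, `decode_clusters` turns the network's edge predictions back into a labeling whose names are arbitrary, matching the permutation invariance in the question.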
