Prototypes

What is a prototype?

Prototype generation is a global interpretability method. Rather than measuring performance on test data, it extracts learned features directly from the model itself to show what the model has learned. This matters because there's no guarantee that your test data covers every potential failure mode. Prototypes give you another way to understand what your model has learned and to predict how it will behave on unseen data in deployment.

So what is a prototype? For each class that your model has been trained to predict, we can generate an input that maximises the model's predicted probability for that class. That input is the model's prototype for the class: a representation of what the model 'thinks' that class is.
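The idea of an input that maximises the probability of a class can be illustrated with gradient ascent on the input. This general technique is often called activation maximisation. The example below is only a minimal sketch under stated assumptions. It is not the algorithm used to generate the images in this document. It uses a hypothetical toy linear softmax classifier (weights `W`, biases `b`) and a hypothetical helper `generate_prototype`. It also adds an assumed L2 penalty (`l2`) to keep the input from growing without bound. For a linear softmax model, the gradient of log p(c | x) with respect to x is `W[c]` minus the probability-weighted average of the rows of `W`, so no autodiff library is needed.

```python
import math


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def logits(W, b, x):
    """Logits of a (hypothetical) linear classifier: z_k = W[k] . x + b[k]."""
    return [sum(w_j * x_j for w_j, x_j in zip(row, x)) + b_k
            for row, b_k in zip(W, b)]


def generate_prototype(W, b, target, steps=500, lr=0.1, l2=0.01):
    """Gradient ascent on log p(target | x) - l2 * ||x||^2, starting from x = 0.

    The L2 penalty is an assumption of this sketch; it stops the input
    from growing without bound while the target probability saturates.
    """
    n_classes, n_features = len(W), len(W[0])
    x = [0.0] * n_features
    for _ in range(steps):
        p = softmax(logits(W, b, x))
        # Probability-weighted average of the weight rows: sum_k p_k * W[k]
        expected_row = [sum(p[k] * W[k][j] for k in range(n_classes))
                        for j in range(n_features)]
        # d/dx log p_target = W[target] - expected_row; minus the L2 gradient
        grad = [W[target][j] - expected_row[j] - 2 * l2 * x[j]
                for j in range(n_features)]
        x = [x_j + lr * g_j for x_j, g_j in zip(x, grad)]
    return x


if __name__ == "__main__":
    # Hypothetical toy model: 3 input features, 2 classes.
    W = [[2.0, 0.0, -1.0],
         [-1.0, 1.5, 0.5]]
    b = [0.0, 0.0]
    proto = generate_prototype(W, b, target=0)
    print("prototype for class 0:", [round(v, 2) for v in proto])
    print("p(class 0 | prototype):", round(softmax(logits(W, b, proto))[0], 3))
```

The features that come out strongly positive in a prototype are the ones the model relies on for that class. Real image models need automatic differentiation and usually stronger regularisation, such as smoothness or natural-image priors. Without it, the optimised input tends to look like high-frequency noise rather than a recognisable image.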

For example, if you have a model trained to diagnose cancer from biopsy slides, prototype generation can show you what the model has learned to look for: what it 'thinks' malignant cells look like. You can then check whether it's looking for the right features. You can also make sure it hasn't learned spurious correlations from its training data that would cause dangerous mistakes in deployment (e.g. looking for lab markings on the slides rather than at cell morphology).

This is an example from a model trained to classify food. We generated prototypes of the classes ice cream, hamburger, pancakes, waffles, and baklava:

These images were generated without using any dataset. Our algorithm extracts, directly from the model, the features that maximise the probability of each given output class. The features in these images are the ones the model most strongly associates with each class.

Look at the prototypes for pancakes and waffles, and notice the berries. This model has learned a spurious correlation from its training data: it treats berries as an essential part of these classes. This may lead it to misclassify images that contain berries, a failure you might never have noticed from test-set performance alone.