If you have ever tried solving a Classification task with a Machine Learning (ML) algorithm, you have probably heard of the Confusion matrix as a way to evaluate the model’s performance. On this page, we will:
Let’s jump in.
A Confusion matrix is a table used to evaluate the performance of an ML model on a Classification task. Technically speaking, the matrix itself is not really a Machine Learning metric. It is more of a summary of the model’s predictions that serves as a basis for various metrics. Many ML Classification metrics are calculated on top of the Confusion matrix because it records every combination of predicted and actual classes.
In detail, the Confusion matrix is the basis for such metrics as:
One of the substantial advantages of the Confusion matrix is its versatility. You can use it for both binary and multiclass Classification tasks. Let’s break these cases down one by one.
The Confusion matrix for the binary case looks as follows.
To calculate it, you need to compare the classifier’s predictions with the ground truth class labels. As you can see, the names of the rows and the columns in the table contain the words 'Positive' and 'Negative'. These words indicate the classes (as a reminder, we are reviewing the binary case).
The Confusion matrix for the binary case derives four types of predictions. These are:
To calculate the matrix, you simply need to count how many times each prediction/ground truth combination occurs and write that count into the matching cell of the table.
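As a minimal sketch of this counting step, the snippet below tallies the four prediction/ground truth combinations of a binary task in plain Python. The helper name `binary_confusion_counts`, the label encoding (1 for Positive, 0 for Negative), and the sample labels are assumptions made for illustration and do not come from any library.

```python
from collections import Counter


def binary_confusion_counts(y_true, y_pred, positive=1):
    """Count TP, FP, FN, and TN for a binary task (hypothetical helper)."""
    counts = Counter()
    for truth, pred in zip(y_true, y_pred):
        if pred == positive:
            # Predicted Positive: correct if the ground truth is Positive too
            counts["TP" if truth == positive else "FP"] += 1
        else:
            # Predicted Negative: a miss if the ground truth is Positive
            counts["FN" if truth == positive else "TN"] += 1
    return {key: counts[key] for key in ("TP", "FP", "FN", "TN")}


# Illustrative labels: 1 = Positive, 0 = Negative
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 0, 1, 0, 0, 1]

counts = binary_confusion_counts(y_true, y_pred)
print(counts)
```

Each sample lands in exactly one of the four cells, so the four counts always add up to the number of samples.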
Strictly speaking, a Confusion matrix C is such that C(i, j) equals the number of observations known to be in group i and predicted to be in group j. So, the multiclass case simply expands the binary one. Let’s say we have three classes: 'Positive', 'Negative', and 'Neutral'. In this case, the matrix will look as follows:
So, the only thing that changed was the addition of one more row and column. With four classes, the matrix would have four rows and four columns, that is, two more of each than in the binary case. It is as simple as that.
Also, as you can see, the multiclass case does not have strict terminology for the prediction/ground truth combinations. Still, the matrix calculation process is the same: count each prediction/ground truth combination, write the counts into the table, and be happy with yourself.
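The C(i, j) definition above can be sketched in a few lines of plain Python. The function name `confusion_matrix_from_labels` and the sample labels are hypothetical and only serve to illustrate the idea; rows follow the ground truth and columns follow the prediction, as in the definition.

```python
def confusion_matrix_from_labels(y_true, y_pred, classes):
    """Build C where C[i][j] counts samples of class i predicted as class j."""
    index = {label: position for position, label in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for truth, pred in zip(y_true, y_pred):
        matrix[index[truth]][index[pred]] += 1
    return matrix


classes = ["Positive", "Negative", "Neutral"]

# Made-up labels, for illustration only
y_true = ["Positive", "Negative", "Neutral", "Neutral", "Positive", "Negative"]
y_pred = ["Positive", "Neutral", "Neutral", "Positive", "Positive", "Negative"]

matrix = confusion_matrix_from_labels(y_true, y_pred, classes)
for label, row in zip(classes, matrix):
    print(f"{label:>8}: {row}")
```

Adding a class only means adding one more entry to `classes`; the counting loop stays the same.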
The Confusion matrix is not straightforward to analyze and interpret (especially the multiclass one). It is pretty helpful if you want to check the direct tradeoff between False Positives and False Negatives for binary Classification, but beyond that, the raw table is hard to read at a glance and does not give you a single number to compare models by.
Still, as mentioned above, you can calculate more interpretable metrics with the Confusion matrix results. Check out the metrics based on the Confusion matrix section to learn more.
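To illustrate how a metric can be calculated on top of the matrix, the sketch below derives accuracy, precision, and recall from the four binary counts. These three metrics are standard choices picked here as an assumption, the counts are made up, and the function name `metrics_from_counts` is hypothetical.

```python
def metrics_from_counts(tp, fp, fn, tn):
    """Derive a few common metrics from binary confusion matrix counts."""
    total = tp + fp + fn + tn
    return {
        # Share of all predictions that were correct
        "accuracy": (tp + tn) / total if total else 0.0,
        # Share of Positive predictions that were actually Positive
        "precision": tp / (tp + fp) if (tp + fp) else 0.0,
        # Share of actual Positives that the model found
        "recall": tp / (tp + fn) if (tp + fn) else 0.0,
    }


# Made-up counts, for illustration only
metrics = metrics_from_counts(tp=40, fp=10, fn=20, tn=30)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

Returning 0.0 when a denominator is zero is a convention chosen for this sketch; other tools may raise a warning or return NaN instead.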
Let’s check out a simple example.
Imagine solving a binary Classification task. For example, you are trying to determine whether a cat or a dog is in an image. You have a model and want to evaluate its performance using the Confusion matrix. You pass 15 pictures of cats and 20 pictures of dogs to the model. Of the 15 cat images, the algorithm predicts 9 as dogs, and of the 20 dog images, it predicts 6 as cats. It is time to build a Confusion matrix.
The workflow is straightforward. Let’s say that the cat images are a Positive class, whereas the dog pictures are a Negative one:
Here is the matrix itself:
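The cat/dog numbers from this example can be turned into the four cells with simple arithmetic. The snippet below is a sketch that assumes rows hold the ground truth and columns hold the prediction; the variable names are made up for illustration.

```python
# Counts from the example: 15 cat images (9 predicted as dogs)
# and 20 dog images (6 predicted as cats).
# Cat is the Positive class, dog is the Negative class.
cats_total, cats_as_dogs = 15, 9
dogs_total, dogs_as_cats = 20, 6

tp = cats_total - cats_as_dogs  # cats predicted as cats
fn = cats_as_dogs               # cats predicted as dogs
fp = dogs_as_cats               # dogs predicted as cats
tn = dogs_total - dogs_as_cats  # dogs predicted as dogs

# Assumed layout: rows = ground truth, columns = prediction
matrix = [
    [tp, fn],  # actual cat
    [fp, tn],  # actual dog
]

print("TP:", tp, "FN:", fn, "FP:", fp, "TN:", tn)
for label, row in zip(["actual cat", "actual dog"], matrix):
    print(label, row)
```

If your table puts the ground truth in columns instead, the same four numbers apply and only the two off-diagonal cells swap places.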
Ok, great. You decide to expand the task and add another class, for example, the bird one. You pass 15 pictures with a cat, 20 images with a dog, and 12 pictures with a bird to the model. The predictions are as follows:
Let’s build the matrix.
If you want to check whether your Confusion matrix is correct, simply sum the values in each column. If everything is fine, you will get the initial number of class samples passed to the model. For example:
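The same sanity check can be sketched in code, using the cat/dog counts from the binary example. Note the assumption: the matrix here follows the C(i, j) layout defined earlier, with the ground truth in rows, so the class totals come from the row sums. If your table puts the ground truth in columns, sum the columns instead. The helper name `class_totals` is hypothetical.

```python
def class_totals(matrix):
    """Sum each row; with rows = ground truth this gives samples per class."""
    return [sum(row) for row in matrix]


# Cat/dog matrix from the binary example (rows = ground truth, assumed layout)
matrix = [
    [6, 9],   # actual cat: 6 predicted as cat, 9 predicted as dog
    [6, 14],  # actual dog: 6 predicted as cat, 14 predicted as dog
]
samples_passed = [15, 20]  # 15 cat images and 20 dog images

totals = class_totals(matrix)
print("Totals per ground truth class:", totals)
print("Matrix is consistent:", totals == samples_passed)

# Summing along the other axis gives how often each class was predicted
predicted_totals = [sum(column) for column in zip(*matrix)]
print("Totals per predicted class:", predicted_totals)
```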
The Confusion matrix is widely used in the industry, so most major Machine and Deep Learning libraries have their own implementation of it. For this page, we prepared three code blocks showing how to build a Confusion matrix in Python. In detail, you can check out:
Scikit-learn is the most popular Python library for classical Machine Learning. From our experience, Sklearn is the tool you will likely use the most to build a Confusion matrix. Fortunately, you can do it in just a few lines of code.
```python
# Importing the function
from sklearn.metrics import confusion_matrix

# Initializing the arrays (multiclass ones)
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

# Printing the result (rows are true classes, columns are predicted classes)
print(confusion_matrix(y_true, y_pred))
```
Beyond the basic functionality, Sklearn has various Confusion matrix options implemented. You should definitely check them out to simplify your workflow.
```python
# Install the library first: pip install torchmetrics
# (in a notebook cell: !pip install torchmetrics)

# Importing the library
import torch
from torchmetrics import ConfusionMatrix

# Initializing the input tensors
target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])

# Calculating the matrix and printing the result
# Note: recent torchmetrics versions require the `task` argument
confmat = ConfusionMatrix(task="binary", num_classes=2)
print(confmat(preds, target))
```
```python
# Importing the library
import tensorflow as tf

# Initializing the input tensors
labels = tf.constant([1, 3, 4], dtype=tf.int32)
predictions = tf.constant([1, 2, 3], dtype=tf.int32)

# Printing the input tensors
print('Labels: ', labels)
print('Predictions: ', predictions)

# Calculating the matrix
res = tf.math.confusion_matrix(labels, predictions)

# Printing the result
print('Confusion matrix: ', res)
```