Confusion Matrix Plotting – The Simplest Way


This tutorial shows how to plot a confusion matrix in Python using a heatmap.

1. What is a Confusion Matrix?

A confusion matrix is a table used to evaluate the performance of a classification model. It counts the model's predictions by true class and predicted class. In this tutorial, rows are true classes and columns are predicted classes. For a binary classifier, this gives the number of true positive (TP), false positive (FP), true negative (TN), and false negative (FN) predictions. These counts are the basis for common evaluation metrics such as accuracy, precision, recall, and F1 score.
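To make the counting concrete, here is a minimal sketch that builds a confusion matrix from a list of true labels and a list of predicted labels. It uses the same layout as the plot later in this tutorial: rows are true labels and columns are predicted labels. The helper `build_confusion_matrix` and the example labels are hypothetical and are not taken from a library. scikit-learn's `sklearn.metrics.confusion_matrix` follows the same row and column convention.

```python
def build_confusion_matrix(y_true, y_pred, labels):
    """Hypothetical helper: count (true label, predicted label) pairs.

    Row i counts samples whose true label is labels[i];
    column j counts samples that were predicted as labels[j].
    """
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for actual, predicted in zip(y_true, y_pred):
        matrix[index[actual]][index[predicted]] += 1
    return matrix


y_true = ["cat", "cat", "dog", "dog", "dog", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "cat", "bird"]
cm = build_confusion_matrix(y_true, y_pred, labels=["cat", "dog", "bird"])
print(cm)  # [[1, 1, 0], [1, 2, 0], [0, 0, 1]]
```

Every sample lands in exactly one cell, so the matrix entries sum to the number of samples. Correct predictions accumulate on the diagonal.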

For a binary classifier, the four cells of a confusion matrix are:

  1. True Positive (TP): The number of cases in which the model correctly predicted the positive class.
  2. False Positive (FP): The number of cases in which the model predicted the positive class, but it was actually negative.
  3. True Negative (TN): The number of cases in which the model correctly predicted the negative class.
  4. False Negative (FN): The number of cases in which the model predicted the negative class, but it was actually positive.
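The sketch below derives the four counts from binary labels and then computes the metrics mentioned above. It uses the standard formulas:

- accuracy = (TP + TN) / total
- precision = TP / (TP + FP)
- recall = TP / (TP + FN)
- F1 = harmonic mean of precision and recall

The function names are hypothetical. Returning 0.0 when a denominator is zero is an assumption of this sketch; some libraries warn or raise in that case instead.

```python
def count_outcomes(y_true, y_pred, positive=1):
    """Hypothetical helper: return (TP, FP, TN, FN) for a binary problem."""
    tp = fp = tn = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive:
            if actual == positive:
                tp += 1  # predicted positive, actually positive
            else:
                fp += 1  # predicted positive, actually negative
        elif actual == positive:
            fn += 1  # predicted negative, actually positive
        else:
            tn += 1  # predicted negative, actually negative
    return tp, fp, tn, fn


def safe_div(num, den):
    # Assumption of this sketch: treat x/0 as 0.0 instead of raising.
    return num / den if den else 0.0


def classification_metrics(tp, fp, tn, fn):
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    return {
        "accuracy": safe_div(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
    }


y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 0, 1, 1, 1, 0, 0, 1]
counts = count_outcomes(y_true, y_pred)
print(counts)  # (3, 2, 2, 1)
print(classification_metrics(*counts))
```

Here accuracy is 5/8 = 0.625, precision is 3/5 = 0.6, and recall is 3/4 = 0.75. Precision and recall penalize different cells (FP versus FN), which is why they can disagree.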

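The confusion matrix can be used to diagnose the strengths and weaknesses of a classification model and to compare different models. For example, large off-diagonal entries show which classes the model confuses with each other, which is more useful than a single accuracy number when deciding how to improve it. For a multi-class problem, the four counts above are computed per class by treating that class as positive and all other classes as negative.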
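As an illustration of that diagnosis, this minimal sketch finds the largest off-diagonal cell, that is, the pair of classes the model confuses most often. It reuses the 6-class example matrix plotted later in this tutorial, with rows as true labels and columns as predicted labels. The helper `most_confused_pair` is hypothetical.

```python
def most_confused_pair(matrix, labels):
    """Hypothetical helper: return (true_label, predicted_label, count)
    for the largest off-diagonal cell (first one wins on ties)."""
    best = None
    for i, row in enumerate(matrix):
        for j, count in enumerate(row):
            # Diagonal cells are correct predictions, so skip them.
            if i != j and (best is None or count > best[2]):
                best = (labels[i], labels[j], count)
    return best


data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]
labels = ["A", "B", "C", "D", "E", "F"]
print(most_confused_pair(data, labels))  # ('B', 'C', 6)
```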

2. Confusion Matrix in Python

The script below uses Seaborn and Matplotlib to plot a confusion matrix as a heatmap and save it as an image file.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import seaborn
import matplotlib.pyplot as plt


def plot_confusion_matrix(data, labels, output_filename):
    """Plot confusion matrix using heatmap.

    Args:
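        data (list of list of int): Confusion matrix counts; row i holds the
            samples whose true label is labels[i], column j those predicted
            as labels[j].
        labels (list): Class names plotted along the x and y axes, in the
            same order as the rows and columns of data.
        output_filename (str): Path to output file.

    """
    # Apply the Seaborn theme once; font_scale enlarges tick labels and annotations.
    seaborn.set_theme(color_codes=True, font_scale=1.4)
    fig, ax = plt.subplots(figsize=(9, 6))

    # fmt="d" prints integer counts as-is (the default ".2g" turns 100 into "1e+02").
    # Passing the labels here keeps them aligned with the heatmap cells.
    seaborn.heatmap(data, annot=True, fmt="d", cmap="YlGnBu",
                    xticklabels=labels, yticklabels=labels,
                    cbar_kws={"label": "Count"}, ax=ax)

    ax.set_title("Confusion Matrix")
    ax.set(ylabel="True Label", xlabel="Predicted Label")

    fig.savefig(output_filename, bbox_inches="tight", dpi=300)
    plt.close(fig)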

# define data
data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]

# define labels
labels = ["A", "B", "C", "D", "E", "F"]

# create confusion matrix
plot_confusion_matrix(data, labels, "confusion_matrix.png")

The data passed to plot_confusion_matrix is a list of lists. Row i holds the samples whose true label is labels[i], and column j counts how many of them were predicted as labels[j]. Correct predictions therefore lie on the diagonal. For example, 13 samples of class A were classified correctly, while 6 samples of class B were mistaken for class C.
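The same matrix can be read numerically as well as visually. In this minimal sketch, the row total of a class is the number of samples that truly belong to it, and the column total is the number of samples predicted as it. That gives per-class precision and recall, computed one-vs-rest. The helper `per_class_metrics` is hypothetical, and returning 0.0 for an empty row or column is an assumption.

```python
def per_class_metrics(matrix, labels):
    """Hypothetical helper: return {label: (precision, recall)} from a square matrix."""
    results = {}
    for k, label in enumerate(labels):
        tp = matrix[k][k]                          # correct predictions of class k
        actual = sum(matrix[k])                    # row total: samples truly in class k
        predicted = sum(row[k] for row in matrix)  # column total: samples predicted as k
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        results[label] = (precision, recall)
    return results


data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]
labels = ["A", "B", "C", "D", "E", "F"]

metrics = per_class_metrics(data, labels)
for label, (precision, recall) in metrics.items():
    print(f"{label}: precision={precision:.2f} recall={recall:.2f}")

# Overall accuracy is the diagonal sum divided by the total count.
correct = sum(data[k][k] for k in range(len(data)))
total = sum(sum(row) for row in data)
print(f"accuracy={correct / total:.2f}")  # accuracy=0.83
```

For this data, class B has the lowest recall (9/19, about 0.47), mostly because of the B-to-C errors. Class C has the lowest precision (16/24, about 0.67) because it absorbs those errors.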

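The script saves the plot to confusion_matrix.png. Its appearance is easy to adjust, for example by choosing a different cmap, changing figsize, or passing a different annotation format through the fmt argument of seaborn.heatmap.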
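One common modification is to plot row-normalized values instead of raw counts. Each row then sums to 1, and the diagonal shows the per-class recall, which helps when classes have very different sizes. This is a minimal sketch with a hypothetical helper `normalize_rows`. Leaving empty rows at 0 is an assumption. scikit-learn's `confusion_matrix` offers similar behavior through its `normalize="true"` option.

```python
def normalize_rows(matrix):
    """Hypothetical helper: divide each row by its total so every row sums to 1.

    Rows with no samples are left as zeros (an assumption of this sketch).
    """
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized


data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]
normalized = normalize_rows(data)
print([round(value, 2) for value in normalized[1]])  # [0.16, 0.47, 0.32, 0.0, 0.05, 0.0]
```

When plotting the normalized matrix, use a float annotation format such as fmt=".2f" in seaborn.heatmap. An integer format like fmt="d" cannot display fractions.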


Conclusion

In summary, this tutorial showed how to use Seaborn and Matplotlib to plot a confusion matrix as a heatmap.
