Confusion Matrix Plotting – The Simplest Way

By: onestop_data

This tutorial shows how to plot a confusion matrix in Python using a heatmap.

1. Confusion Matrix in Python

The function below uses Seaborn and Matplotlib to draw a confusion matrix as a heatmap and save it to an image file.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import seaborn
import matplotlib.pyplot as plt


def plot_confusion_matrix(data, labels, output_filename):
    """Plot confusion matrix using heatmap.

    Args:
        data (list of list): List of lists with confusion matrix data.
        labels (list): Labels which will be plotted across x and y axis.
        output_filename (str): Path to output file.

    """
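    # Configure seaborn once, before any text is drawn, so the title
    # and the tick labels pick up the same font scale.
    seaborn.set(color_codes=True, font_scale=1.4)
    plt.figure(1, figsize=(9, 6))

    plt.title("Confusion Matrix")

    # annot=True writes each cell's value into the heatmap.
    ax = seaborn.heatmap(data, annot=True, cmap="YlGnBu", cbar_kws={'label': 'Scale'})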

    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    ax.set(ylabel="True Label", xlabel="Predicted Label")

    plt.savefig(output_filename, bbox_inches='tight', dpi=300)
    plt.close()

Next, define the data to plot. The nested list below stands in for your own confusion matrix: each inner list is one row (a true label), and each position within it is one column (a predicted label).

# define data
data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]

# define labels
labels = ["A", "B", "C", "D", "E", "F"]

# create confusion matrix
plot_confusion_matrix(data, labels, "confusion_matrix.png")
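
In practice the matrix is usually counted from two lists of labels rather than typed in by hand. The following is a minimal sketch of that step in plain Python. The helper name build_confusion_matrix and the sample labels are assumptions made for illustration; they are not part of the original script.

```python
from collections import Counter


def build_confusion_matrix(true_labels, predicted_labels, labels):
    """Count (true, predicted) pairs into a list of lists.

    Hypothetical helper, not part of the tutorial's script.
    Rows follow the true label, columns follow the predicted label.
    """
    if len(true_labels) != len(predicted_labels):
        raise ValueError("true_labels and predicted_labels must have equal length")
    # Counter returns 0 for pairs that never occur.
    pair_counts = Counter(zip(true_labels, predicted_labels))
    return [[pair_counts[(t, p)] for p in labels] for t in labels]


# Made-up labels, purely for illustration.
y_true = ["A", "A", "B", "B", "C", "C"]
y_pred = ["A", "B", "B", "B", "C", "A"]
matrix = build_confusion_matrix(y_true, y_pred, ["A", "B", "C"])
```

The resulting list of lists has the same shape as data above, so it can be passed to plot_confusion_matrix together with the same label list.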

Running the script saves the heatmap to confusion_matrix.png. The plot is easy to modify, for example by changing the color map or the figure size.
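
One common modification is to plot per-class fractions instead of raw counts, so that classes with different numbers of samples are comparable. Here is a minimal sketch, assuming a hypothetical helper named normalize_rows that is not part of the original script:

```python
def normalize_rows(matrix):
    """Return a copy of matrix with each row scaled to sum to 1.0.

    Hypothetical helper for illustration. Rows whose sum is zero
    are returned as all zeros to avoid dividing by zero.
    """
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([value / total if total else 0.0 for value in row])
    return normalized


# Same example matrix as in the tutorial.
data = [[13, 1, 1, 0, 2, 0],
        [3, 9, 6, 0, 1, 0],
        [0, 0, 16, 2, 0, 0],
        [0, 0, 0, 13, 0, 0],
        [0, 0, 0, 0, 15, 0],
        [0, 0, 1, 0, 0, 15]]

fractions = normalize_rows(data)
```

Passing fractions instead of data to plot_confusion_matrix should color each cell by the share of its true class. If the annotations need a fixed number of decimals, seaborn.heatmap accepts a fmt argument such as fmt=".2f".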

Conclusion

In summary, this tutorial showed how to use Seaborn and Matplotlib to plot a confusion matrix as a heatmap.
