# Better Heatmaps and Correlation Matrix Plots in Python

Correlation Matrix plots

[Note: You can also read this post on Medium.]

You already know that if you have a data set with many columns, a good way to quickly check correlations among columns is by visualizing the correlation matrix as a heatmap.

But is a simple heatmap the best way to do it?

For illustration, I’ll use the Automobile Data Set, containing various characteristics of a number of cars. You can also find a clean version of the data with column headers here.

Let’s start by making a correlation matrix heatmap for the data set.

Great! Green means positive, red means negative. The stronger the color, the larger the correlation magnitude. Now looking at the chart above, think about the following questions:

• Where do your eyes jump first when you look at the chart?
• What’s the strongest and what’s the weakest correlated pair (except the main diagonal)?
• What are the three variables most correlated with price?

If you’re like most people, you’ll find it hard to map the color scale to numbers and vice versa.

Distinguishing positive from negative is easy, and so is distinguishing 0 from 1. But what about the second question? Finding the highest negative and positive correlations means finding the strongest red and green. To do that I need to carefully scan the entire grid. Try to answer it again and notice how your eyes jump around the plot, and sometimes go to the legend.

Now consider the following plot:

In addition to color, we’ve added size as a parameter to our heatmap. The size of each square corresponds to the magnitude of the correlation it represents, that is:

$$\text{size}(c_1, c_2) \sim \lvert \operatorname{corr}(c_1, c_2) \rvert$$

Now try to answer the questions using the latter plot. Notice how weak correlations visually disappear, and your eyes are immediately drawn to areas where there’s high correlation. Also note that it’s now easier to compare magnitudes of negative vs positive values (lighter red vs lighter green), and we can also compare values that are further apart.

If we’re mapping magnitudes, it’s much more natural to link them to the size of the representing object than to its color. That’s exactly why on bar charts you would use height to display measures, and colors to display categories, but not vice versa.

Discrete Joint Distributions

Let’s see how the cars in our data set are distributed according to horsepower and drivetrain layout. That is, we want to visualize the following table:

| horsepower ↓ / drive-wheels → | 4wd | fwd | rwd |
|---|---|---|---|
| Low (0-100) | 5 | 89 | 15 |
| Medium (100-150) | 3 | 24 | 35 |
| High (150+) | 0 | 5 | 25 |
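The table is a cross-tabulation of two categorical variables, so pandas can compute it directly. Below is a minimal sketch. It assumes the column names `horsepower` and `drive-wheels` from the clean data set. The ten rows are invented toy values, not the real data. The bins are right-closed, so a car with exactly 100 hp counts as Low.

```python
import pandas as pd

# Invented toy rows for illustration only (NOT the real Automobile data)
cars = pd.DataFrame({
    "horsepower": [48, 70, 95, 102, 111, 121, 160, 184, 207, 88],
    "drive-wheels": ["fwd", "fwd", "4wd", "rwd", "fwd", "rwd", "rwd", "rwd", "fwd", "4wd"],
})

# Bin horsepower into the three ranges used in the table (right-closed bins)
cars["hp-bin"] = pd.cut(
    cars["horsepower"],
    bins=[0, 100, 150, float("inf")],
    labels=["Low (0-100)", "Medium (100-150)", "High (150+)"],
)

# Count the cars in every (horsepower bin, drive-wheels) combination
table = pd.crosstab(cars["hp-bin"], cars["drive-wheels"])
print(table)
```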

Consider the following two ways to do it:

The second version, where we use square size to display counts, makes it effortless to determine which group is the largest or smallest. It also gives some intuition about the marginal distributions, all without needing to refer to a color legend.

Great. So how do I make these plots?

To make a regular heatmap, we simply used the Seaborn heatmap function, with a bit of additional styling.
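For reference, here is a minimal sketch of such a plain heatmap. The random DataFrame and its column names stand in for the car data. The styling arguments are our choice and not necessarily the ones used for the chart above.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Stand-in data: replace with the real data set's numeric columns
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(100, 4)),
                    columns=["price", "horsepower", "width", "city-mpg"])
corr = data.corr()

ax = sns.heatmap(
    corr,
    vmin=-1, vmax=1, center=0,                   # pin the color scale to the full [-1, 1] range
    cmap=sns.diverging_palette(20, 220, n=200),  # hue 20 (red) at -1, hue 220 at +1
    square=True,                                 # square cells
)
ax.tick_params(axis="x", labelrotation=45)
plt.savefig("heatmap.png")
```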

For the second kind, there’s no trivial way to make it using matplotlib or seaborn. We could use corrplot from biokit, but it helps with correlations only and isn’t very useful for two-dimensional distributions.

Building a robust parametrized function that enables us to make heatmaps with sized markers is a nice exercise in matplotlib, so I’ll show you how to do it step by step.

We’ll start by using a simple scatter plot with squares as markers. Then we’ll fix some issues with it, add color and size as parameters, make it more general and robust to various types of input, and finally make a wrapper function corrplot that takes the result of the DataFrame.corr method and plots a correlation matrix, supplying all the necessary parameters to the more general heatmap function.

It’s just a scatter plot

If we want to plot elements on a grid made by two categorical axes, we can use a scatter plot.
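Here is a minimal sketch of that scatter-plot heatmap. The function name `heatmap` and the `size_scale` factor of 500 are our own choices. Each distinct label gets an integer position, and a square marker (`marker="s"`) is drawn at every (x, y) pair. Custom ticks then put the labels back on the axes.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove in a notebook
import matplotlib.pyplot as plt


def heatmap(x, y, size, size_scale=500):
    """Draw one square per (x, y) pair; marker area grows linearly with size."""
    fig, ax = plt.subplots()

    # Map each distinct label to an integer position on its axis
    x_labels = sorted(set(x))
    y_labels = sorted(set(y))
    x_to_num = {label: i for i, label in enumerate(x_labels)}
    y_to_num = {label: i for i, label in enumerate(y_labels)}

    ax.scatter(
        [x_to_num[v] for v in x],
        [y_to_num[v] for v in y],
        s=[s * size_scale for s in size],  # s is the marker area in points^2
        marker="s",                        # square markers
    )

    # Show the labels instead of the integer positions
    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=45, horizontalalignment="right")
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)
    return ax


# Usage with hypothetical values: three cells of a 2x2 grid
ax = heatmap(["a", "a", "b"], ["c", "d", "c"], [1.0, 0.5, 0.2])
```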

Looks like we’re onto something. But I said it’s just a scatterplot, and there’s quite a lot happening in the previous code snippet.

Since we pass x and y to the scatter plot as numeric positions, we need to map our column names to numbers. And since we want our axis ticks to show column names instead of those numbers, we need to set custom ticks and tick labels. Finally, there’s code that loads the dataset, selects a subset of columns, calculates all the correlations, melts the data frame (the inverse of creating a pivot table) and feeds its columns to our heatmap function.
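The data-preparation half can be sketched as follows. It assumes `data` is a DataFrame of numeric columns; random stand-in values are used here. The column names `x`, `y` and `value` for the melted frame are our own labels.

```python
import numpy as np
import pandas as pd

# Stand-in for the selected numeric columns of the car data
rng = np.random.default_rng(42)
data = pd.DataFrame(rng.normal(size=(50, 3)), columns=["price", "horsepower", "width"])

corr = data.corr()  # 3x3 correlation matrix

# Melt the wide matrix into long format: one row per (column, column) pair
long = pd.melt(corr.reset_index(), id_vars="index")
long.columns = ["x", "y", "value"]

print(long.head())
```

The three columns can then be handed to a scatter-based heatmap, for example as `x=long["x"]`, `y=long["y"]` and `size=long["value"].abs()`.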

You may have noticed that our squares are placed where our gridlines intersect instead of being centered in their cells. To move the squares to cell centers, we’ll actually move the grid: we’ll turn off the major gridlines and set minor gridlines to run right in between our axis ticks.
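A minimal sketch of the gridline trick on a toy 2x2 grid follows. The marker size of 500 is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.scatter([0, 0, 1, 1], [0, 1, 0, 1], s=500, marker="s")  # toy 2x2 grid
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])

# Hide the major gridlines, which pass through the square centers
ax.grid(False, "major")
# Put minor ticks halfway between the major ticks and draw gridlines there instead
ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
ax.grid(True, "minor")
```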

That’s better. But now the left and bottom sides look cropped. That’s because the lower limits of our axes are set to 0. We’ll sort this out by setting the lower limit for both axes to -0.5. Remember, our points are displayed at integer coordinates, so our gridlines are at .5 coordinates.
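In code, this amounts to extending each axis half a cell past the first and last positions. The labels and points below are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove in a notebook
import matplotlib.pyplot as plt

x_labels = ["a", "b", "c"]  # hypothetical labels
y_labels = ["d", "e"]

fig, ax = plt.subplots()
ax.scatter([0, 1, 2, 0], [0, 1, 0, 1], s=500, marker="s")

# Squares sit at integer coordinates, so cell borders fall at n + 0.5;
# start each axis at -0.5 and end it half a cell after the last position
ax.set_xlim(-0.5, len(x_labels) - 0.5)
ax.set_ylim(-0.5, len(y_labels) - 0.5)
```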

Give it some color

Now comes the fun part. We need to map the possible range of values for correlation coefficients, [-1, 1], to a color palette. We’ll use a diverging palette, going from red for -1, all the way to green for 1. Looking at Seaborn color palettes, it seems that we’ll do just fine with something like

sns.palplot(sns.diverging_palette(220, 20, n=7))

But let’s first flip the order of colors and make the palette smoother by adding more steps between red and green:

palette = sns.diverging_palette(20, 220, n=256)

Seaborn color palettes are just arrays of color components, so in order to map a correlation value to the appropriate color, we ultimately need to map it to an index in the palette array. It’s a simple mapping of one interval to another: [-1, 1] → [0, 1] → {0, 1, …, 255}. More precisely:

$$
\begin{aligned}
v &\in [\mathrm{val}_{\min}, \mathrm{val}_{\max}] \\
t &= \frac{v - \mathrm{val}_{\min}}{\mathrm{val}_{\max} - \mathrm{val}_{\min}}, \qquad t \in [0, 1] \\
\mathrm{ind} &= \operatorname{round}(255\, t), \qquad \mathrm{ind} \in \{0, 1, 2, \dots, 254, 255\}
\end{aligned}
$$
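In code, the mapping might look like the sketch below. The function name `value_to_color` and the clipping step are our additions. For a 256-color palette, `n_colors - 1` is 255, which matches the formula.

```python
def value_to_color(val, palette, val_min=-1.0, val_max=1.0):
    """Map val from [val_min, val_max] to an entry of palette (a list of colors)."""
    n_colors = len(palette)
    # Clip first, so values just outside the range (e.g. 1.0000000000000002) stay valid
    val = min(max(val, val_min), val_max)
    t = (val - val_min) / (val_max - val_min)  # t in [0, 1]
    ind = int(round(t * (n_colors - 1)))       # ind in {0, ..., n_colors - 1}
    return palette[ind]


# Hypothetical three-color palette, just to show the mapping
demo_palette = ["red", "white", "green"]
print(value_to_color(-1, demo_palette))   # red
print(value_to_color(0.1, demo_palette))  # white
print(value_to_color(1, demo_palette))    # green
```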

Just what we wanted. Let’s now add a color bar on the right side of the chart. We’ll use GridSpec to set up a plot grid with 1 row and n columns. Then we’ll use the rightmost column of the plot to display the color bar and the rest to display the heatmap.

There are multiple ways to display a color bar; here we’ll trick our eyes by using a really dense bar chart. We’ll draw n_colors horizontal bars, each colored with its respective color from the palette.
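Below is a minimal sketch of that layout and color bar. It assumes 256 colors and a 1×15 grid; the 15 is arbitrary. Matplotlib's `RdYlGn` colormap stands in for the Seaborn palette, so the example needs one dependency fewer.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove in a notebook
import matplotlib.pyplot as plt
import numpy as np

n_colors = 256
# Stand-in palette sampled from a Matplotlib colormap (the post uses a Seaborn palette)
palette = [plt.cm.RdYlGn(t) for t in np.linspace(0, 1, n_colors)]
color_min, color_max = -1, 1

fig = plt.figure(figsize=(8, 6))
grid = fig.add_gridspec(1, 15, wspace=0.1)  # 1 row, 15 columns
ax = fig.add_subplot(grid[:, :-1])          # first 14 columns: room for the heatmap
cax = fig.add_subplot(grid[:, -1])          # last column: the color bar

# One thin horizontal bar per palette color, stacked from color_min to color_max
bar_y = np.linspace(color_min, color_max, n_colors)
bar_height = bar_y[1] - bar_y[0]
cax.barh(y=bar_y, width=[1] * n_colors, left=0, height=bar_height,
         color=palette, linewidth=0)
cax.set_xlim(0, 1)
cax.set_xticks([])
cax.set_yticks(np.linspace(color_min, color_max, 3))  # ticks at -1, 0, 1
cax.yaxis.tick_right()
```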

And we have our color bar.

We’re almost done. Now we just need to flip the vertical axis so that the correlation of each variable with itself is shown on the main diagonal, make the squares a bit larger, and make the background just a tad lighter so that values around 0 are more visible.
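Here is a sketch of those finishing touches on a toy 3x3 grid. `ax.invert_yaxis()` is one way to flip the axis; reversing the order of the y labels would work too. The gray `#F1F1F1` background and the marker size are our choices.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove in a notebook
import matplotlib.pyplot as plt

labels = ["a", "b", "c"]  # hypothetical column names
pos = range(len(labels))

fig, ax = plt.subplots()
xs = [i for i in pos for _ in pos]
ys = [j for _ in pos for j in pos]
ax.scatter(xs, ys, s=1000, marker="s")  # a larger size scale gives bigger squares
ax.set_xticks(list(pos))
ax.set_xticklabels(labels)
ax.set_yticks(list(pos))
ax.set_yticklabels(labels)
ax.set_xlim(-0.5, len(labels) - 0.5)
ax.set_ylim(-0.5, len(labels) - 0.5)

# Flip the y axis so the self-correlation cells run from top-left to bottom-right
ax.invert_yaxis()
# A light gray background keeps near-zero (nearly white) squares visible
ax.set_facecolor("#F1F1F1")
```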

But let’s first make the entire code more useful.

More parameters!

It would be great if our function could accept more than just a correlation matrix. To do this, we’ll make the following changes:

• Be able to pass color_min, color_max and size_min, size_max as parameters so that we can map ranges other than [-1, 1] to color and size. This will enable us to use the heatmap beyond correlations.
• Use a sequential palette if no palette is specified, and a single color if no color vector is provided.
• Use a constant size if no size vector is provided. Avoid mapping the lowest value to 0 size.
• Make x and y the only required parameters, and pass size, color, size_scale, size_range, color_range, palette, marker as kwargs. Provide sensible defaults for each of the parameters.
• Use list comprehensions instead of pandas apply and map methods, so we can pass any kind of array as x, y, color, size instead of just a pandas.Series.
• Pass any other kwargs to the pyplot.scatter function.
• Make a wrapper function corrplot that accepts a corr() dataframe, melts it, and calls heatmap with a red-green diverging color palette and size/color min-max set to [-1, 1].
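The sketch below is one way to implement the list above using only Matplotlib. It is our reconstruction, not the post's final code, which lives in the Kaggle kernel. Parameter names follow the bullets (`color_range`, `size_range`, `size_scale`, `palette`, `marker`). Several choices are our own assumptions:

- The Matplotlib `Blues` and `RdYlGn` colormaps stand in for Seaborn palettes.
- The 1% minimum square size is arbitrary.
- `corrplot` maps |r| onto a size range of (0, 1).
- The color bar is left out to keep the example short.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering; remove this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def heatmap(x, y, **kwargs):
    """Scatter-based heatmap. Only x and y are required; see the defaults below."""
    # Color: a single color if no vector is given, a sequential palette by default
    color = kwargs.pop("color", [1] * len(x))
    palette = kwargs.pop("palette", None)
    if palette is None:
        palette = [plt.cm.Blues(t) for t in np.linspace(0.2, 1.0, 256)]
    color_min, color_max = kwargs.pop("color_range", (min(color), max(color)))

    # Size: a constant size if no vector is given
    size = kwargs.pop("size", [1] * len(x))
    size_min, size_max = kwargs.pop("size_range", (min(size), max(size)))
    size_scale = kwargs.pop("size_scale", 500)
    marker = kwargs.pop("marker", "s")

    def value_to_color(val):
        if color_max == color_min:
            return palette[-1]
        val = min(max(val, color_min), color_max)  # clip into the color range
        t = (val - color_min) / (color_max - color_min)
        return palette[int(round(t * (len(palette) - 1)))]

    def value_to_size(val):
        if size_max == size_min:
            return size_scale
        val = min(max(val, size_min), size_max)  # clip into the size range
        t = (val - size_min) / (size_max - size_min)
        return (0.01 + 0.99 * t) * size_scale  # lowest value still gets a visible square

    # List comprehensions work for lists, NumPy arrays and pandas Series alike
    x_names = sorted(set(x))
    y_names = sorted(set(y))
    x_to_num = {name: i for i, name in enumerate(x_names)}
    y_to_num = {name: i for i, name in enumerate(y_names)}

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(
        [x_to_num[v] for v in x],
        [y_to_num[v] for v in y],
        s=[value_to_size(v) for v in size],
        c=[value_to_color(v) for v in color],
        marker=marker,
        **kwargs,  # any remaining keyword arguments go to Axes.scatter
    )

    ax.set_xticks(range(len(x_names)))
    ax.set_xticklabels(x_names, rotation=45, horizontalalignment="right")
    ax.set_yticks(range(len(y_names)))
    ax.set_yticklabels(y_names)

    # Gridlines between cells, half-cell margins, flipped y axis, light background
    ax.grid(False, "major")
    ax.grid(True, "minor")
    ax.set_xticks([t + 0.5 for t in range(len(x_names))], minor=True)
    ax.set_yticks([t + 0.5 for t in range(len(y_names))], minor=True)
    ax.set_xlim(-0.5, len(x_names) - 0.5)
    ax.set_ylim(-0.5, len(y_names) - 0.5)
    ax.invert_yaxis()
    ax.set_facecolor("#F1F1F1")
    return ax


def corrplot(corr, size_scale=500):
    """Plot a DataFrame.corr() result: color and size both encode the coefficient."""
    long = pd.melt(corr.reset_index(), id_vars="index")
    long.columns = ["x", "y", "value"]
    return heatmap(
        long["x"], long["y"],
        color=long["value"], color_range=(-1, 1),
        palette=[plt.cm.RdYlGn(t) for t in np.linspace(0, 1, 256)],  # red -> green
        size=long["value"].abs(), size_range=(0, 1),
        size_scale=size_scale,
    )


# Usage with random stand-in data (not the Automobile data set)
rng = np.random.default_rng(0)
data = pd.DataFrame(rng.normal(size=(60, 4)), columns=["a", "b", "c", "d"])
ax = corrplot(data.corr())
```

Because `heatmap` pops every parameter it knows about, anything left in `kwargs` (for example `alpha=0.8`) reaches `Axes.scatter` unchanged.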

That’s quite a lot of boilerplate stuff to cover step by step, so here’s what it looks like when done. You can also check it out in this Kaggle kernel.

Now that we have our corrplot and heatmap functions, in order to create the correlation plot with sized squares, like the one at the beginning of this post, we simply do the following:

And just for fun, let’s make a plot showing how engine power is distributed among car brands in our data set.