from matplotlib import use as mplt_use
mplt_use('Agg')
from deeptools import cm # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import scipy.cluster.hierarchy as sch
from matplotlib import rcParams
import matplotlib.colors as pltcolors
import copy
rcParams['pdf.fonttype'] = 42
rcParams['svg.fonttype'] = 'none'
old_settings = np.seterr(all='ignore')
[docs]
def plot_correlation(corr_matrix, labels, plotFileName, vmax=None,
vmin=None, colormap='jet', image_format=None,
plot_numbers=False, plot_title=''):
num_rows = corr_matrix.shape[0]
# set a font size according to figure length
if num_rows < 6:
font_size = 14
elif num_rows > 40:
font_size = 5
else:
font_size = int(14 - 0.25 * num_rows)
rcParams.update({'font.size': font_size})
# set the minimum and maximum values
if vmax is None:
vmax = 1
if vmin is None:
vmin = 0 if corr_matrix.min() >= 0 else -1
# Compute and plot dendrogram.
fig = plt.figure(figsize=(11, 9.5))
if plot_title:
plt.suptitle(plot_title)
axdendro = fig.add_axes([0.02, 0.12, 0.1, 0.66])
axdendro.set_axis_off()
y_var = sch.linkage(corr_matrix, method='complete')
z_var = sch.dendrogram(y_var, orientation='right',
link_color_func=lambda k: 'darkred')
axdendro.set_xticks([])
axdendro.set_yticks([])
cmap = copy.copy(plt.get_cmap(colormap))
# this line simply makes a new cmap, based on the original
# colormap that goes from 0.0 to 0.9
# This is done to avoid colors that
# are too dark at the end of the range that do not offer
# a good contrast between the correlation numbers that are
# plotted on black.
if plot_numbers:
cmap = pltcolors.LinearSegmentedColormap.from_list(colormap + "clipped",
cmap(np.linspace(0, 0.9, 10)))
cmap.set_under((0., 0., 1.))
# Plot distance matrix.
axmatrix = fig.add_axes([0.13, 0.1, 0.6, 0.7])
index = z_var['leaves']
corr_matrix = corr_matrix[index, :]
corr_matrix = corr_matrix[:, index]
img_mat = axmatrix.pcolormesh(corr_matrix,
edgecolors='black',
cmap=cmap,
vmax=vmax,
vmin=vmin)
axmatrix.set_xlim(0, num_rows)
axmatrix.set_ylim(0, num_rows)
axmatrix.yaxis.tick_right()
axmatrix.set_yticks(np.arange(corr_matrix.shape[0]) + 0.5)
axmatrix.set_yticklabels(np.array(labels).astype('str')[index])
# axmatrix.xaxis.set_label_position('top')
axmatrix.xaxis.set_tick_params(labeltop=True)
axmatrix.xaxis.set_tick_params(labelbottom=False)
axmatrix.set_xticks(np.arange(corr_matrix.shape[0]) + 0.5)
axmatrix.set_xticklabels(np.array(labels).astype('str')[index],
rotation=45,
ha='left')
axmatrix.tick_params(
axis='x',
which='both',
bottom=False,
top=False)
axmatrix.tick_params(
axis='y',
which='both',
left=False,
right=False)
# axmatrix.set_xticks([])
# Plot colorbar.
axcolor = fig.add_axes([0.13, 0.065, 0.6, 0.02])
cobar = plt.colorbar(img_mat, cax=axcolor, orientation='horizontal')
cobar.solids.set_edgecolor("face")
if plot_numbers:
for row in range(num_rows):
for col in range(num_rows):
axmatrix.text(row + 0.5, col + 0.5,
"{:.2f}".format(corr_matrix[row, col]),
ha='center', va='center')
fig.savefig(plotFileName, format=image_format)
fig.close()