Issue
I want to make a clustermap/heatmap of gene presence-absence data from patients, where the genes are grouped into categories (e.g. chemotaxis, endotoxin, etc.) and labelled accordingly. I haven't found any such option in the seaborn documentation. I know how to generate the heatmap; I just don't know how to label groups of y-tick labels with their categories. Here is a sample (unrelated to my work) of what I want to achieve:
Here, the y-tick labels January, February and March are given the group label "Winter", and the other y-tick labels are grouped and labelled the same way.
Solution
I've reproduced your example in seaborn by adapting @Stein's answer to a related question.
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
import datetime
import seaborn as sns
def test_table():
    """Build a small demo DataFrame with a two-level row index.

    Rows: one per month, indexed by a MultiIndex named ('first', 'second')
    holding (month name, season). Month names come from ``strftime('%B')``.
    Columns: the years 1950..1959.
    Cells: integers drawn one at a time with ``np.random.randint(0, 50)``
    from NumPy's global random state, filled column by column.
    """
    month_names = [
        datetime.date(2008, month, 1).strftime('%B') for month in range(1, 13)
    ]
    season_per_month = (
        ['Winter'] * 3 + ['Spring'] * 2 + ['Summer'] * 3 + ['Pre-Winter'] * 4
    )
    row_index = pd.MultiIndex.from_tuples(
        list(zip(month_names, season_per_month)), names=['first', 'second']
    )
    columns = {}
    for year in range(1950, 1960):
        # One value per month, drawn in row order for this year's column.
        columns[year] = [np.random.randint(0, 50) for _ in range(12)]
    return pd.DataFrame(columns, index=row_index)
def add_line(ax, xpos, ypos):
    """Draw a short black horizontal separator in axes coordinates.

    NOTE: the parameter names are swapped relative to how they are used.
    ``xpos`` is the VERTICAL position (axes-fraction y) of the line, and
    ``ypos`` is its left end on the horizontal axis. The line runs from
    ``ypos`` to ``ypos + .2``.

    Clipping is disabled so the line is still drawn outside the axes area,
    e.g. at negative x to the left of the plot.
    """
    left_end, right_end = ypos, ypos + .2
    separator = plt.Line2D(
        [left_end, right_end], [xpos, xpos],
        color='black', transform=ax.transAxes,
    )
    separator.set_clip_on(False)
    ax.add_line(separator)
def label_len(my_index, level):
    """Run-length encode one level of an index.

    Returns a list of ``(label, run_length)`` pairs, one per run of
    *consecutive* equal values in ``my_index.get_level_values(level)``.
    Repeats that are not adjacent form separate runs (``itertools.groupby``
    semantics), so each group must be contiguous in the index to be
    counted as one.
    """
    runs = []
    for label, members in groupby(my_index.get_level_values(level)):
        runs.append((label, len(list(members))))
    return runs
def label_group_bar_table(ax, df):
    """Draw grouped row labels in columns to the left of ``ax``.

    Each level of ``df.index`` gets its own label column, 0.2 axes-widths
    wide. Level 0 is drawn nearest the axes; each further level sits one
    column further left. Within a column, consecutive equal labels form one
    group (see ``label_len``). Each group gets:
      - a separator line at each of its boundaries (via ``add_line``);
      - its label centred in its vertical span.

    Rows are laid out top-down: ``pos`` starts at the top edge, so the first
    row of ``df`` is drawn at the top of the axes. This assumes the plot also
    shows row 0 at the top (as the heatmap in this script is expected to);
    confirm for other plot types.

    Parameters
    ----------
    ax : the Axes the data was plotted on; all drawing uses ``ax.transAxes``.
    df : DataFrame whose (possibly Multi-) index supplies the labels.
    """
    n_rows = df.index.size
    if n_rows == 0:
        # Nothing to label; also avoids a division by zero below.
        return
    scale = 1. / n_rows  # height of one row in axes coordinates
    xpos = -.2  # left edge of the current label column (axes coordinates)
    for level in range(df.index.nlevels):
        pos = n_rows  # start at the top edge (y == 1 in axes coordinates)
        for label, rpos in label_len(df.index, level):
            add_line(ax, pos * scale, xpos)  # upper boundary of this group
            pos -= rpos
            lypos = (pos + .5 * rpos) * scale  # vertical centre of the group
            # va='center' centres the text on lypos. With the default baseline
            # alignment the label would sit above the group's centre.
            ax.text(xpos + .1, lypos, label, ha='center', va='center',
                    transform=ax.transAxes)
        add_line(ax, pos * scale, xpos)  # bottom boundary of the last group
        xpos -= .2
# Demo: heatmap of the test table, with grouped row labels drawn by
# label_group_bar_table in place of the default y tick labels.
df = test_table()
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
sns.heatmap(df, ax=ax)  # draw explicitly on ax rather than the current axes
# Blank the default y tick labels and the y-axis title; the grouped labels
# drawn by label_group_bar_table replace them.
labels = ['' for _ in ax.get_yticklabels()]
ax.set_yticklabels(labels)
ax.set_ylabel('')
label_group_bar_table(ax, df)
# The group labels sit LEFT of the axes, one 0.2-axes-wide column per index
# level, so they span `label_span` axes widths at negative x. Widen the LEFT
# margin (not the bottom one) so that span stays inside the figure. With the
# axes covering [left, right] of the figure width, we need
#     left >= label_span * (right - left),
# i.e. left >= label_span * right / (1 + label_span).
label_span = .2 * df.index.nlevels
right = plt.rcParams['figure.subplot.right']
fig.subplots_adjust(left=label_span * right / (1 + label_span))
plt.show()
Gives:
Hope that helps.
Answered By - CDJB Answer Checked By - Robin (PHPFixing Admin)
0 Comments:
Post a Comment
Note: Only a member of this blog may post a comment.