Changing width of heatmap in Seaborn to compensate

Posted 2019-06-04 22:57

Question:

I have a sentence, say:

Hey I am feeling pretty boring today and the day is dull too

I pass it through the OpenAI sentiment code, which gives me neuron weights. There are as many weights as characters, or a few more.

The neuron weights are:

[ 0.01258736,  0.03544582,  0.08490616,  0.09010842,  0.07180552,
        0.07271874,  0.08906463,  0.09690772,  0.10281454,  0.08131664,
        0.08315734,  0.0790544 ,  0.07770097,  0.07302617,  0.07329235,
        0.06856266,  0.07642639,  0.08199468,  0.09079508,  0.09539193,
        0.09061056,  0.07109602,  0.02138061,  0.02364372,  0.00322057,
        0.01517018,  0.01150052,  0.00627739,  0.00445003,  0.00061127,
        0.0228037 , -0.29226044, -0.40493113, -0.4069235 , -0.39796737,
       -0.39871565, -0.39242673, -0.3537892 , -0.3779315 , -0.36448184,
       -0.36063945, -0.3506464 , -0.36719123, -0.37997353, -0.35103855,
       -0.34472692, -0.36256564, -0.35900915, -0.3619383 , -0.3532831 ,
       -0.35352525, -0.33328298, -0.32929575, -0.33149993, -0.32934144,
       -0.3261477 , -0.32421976, -0.3032671 , -0.47205922, -0.46902984,
       -0.45346943, -0.4518705 , -0.50997925, -0.50997925]
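A quick check shows that the weights line up with characters rather than words, with a few left over. This sketch uses only the sample sentence and the count of the 64 weights listed above. That mismatch is why the plotting code has to truncate or pad `values` before reshaping:

```python
text = 'Hey I am feeling pretty boring today and the day is dull too'
num_weights = 64  # counted from the list of neuron weights above

num_chars = len(text)          # one heatmap cell per character
num_words = len(text.split())  # for comparison only

print(num_chars, num_words, num_weights)
```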

Now what I want to do is plot a heatmap in which positive values show positive sentiment and negative values show negative sentiment. I am plotting the heatmap, but it isn't coming out the way it should.

When the sentence gets longer, the whole sentence gets smaller and smaller until it can't be read. What changes should I make so that it displays better?

Here is my plotting function:

import math
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_neuron_heatmap(text, values, savename=None, negate=False, cell_height=.112, cell_width=.92):
    #n_limit = 832
    cell_height=.325
    cell_width=.15
    n_limit = count
    num_chars = len(text)
    text = list(map(lambda x: x.replace('\n', '\\n'), text))
    num_chars = len(text)
    total_chars = math.ceil(num_chars/float(n_limit))*n_limit
    mask = np.array([0]*num_chars + [1]*(total_chars-num_chars))
    text = np.array(text+[' ']*(total_chars-num_chars))
    values = np.array((values+[0])*(total_chars-num_chars))

    values = values.reshape(-1, n_limit)
    text = text.reshape(-1, n_limit)
    mask = mask.reshape(-1, n_limit)
    num_rows = len(values)
    plt.figure(figsize=(cell_width*n_limit, cell_height*num_rows))
    hmap=sns.heatmap(values, annot=text, mask=mask, fmt='', vmin=-5, vmax=5, cmap='RdYlGn',xticklabels=False, yticklabels=False, cbar=False)
    plt.subplots_adjust() 
    #plt.tight_layout()
    plt.savefig('fig1.png')
    #plt.show()

This is how it displays the lengthy text:

This is what I want it to show:

Here is a link to the full notebook: https://github.com/yashkumaratri/testrepo/blob/master/heatmap.ipynb

Mad Physicist, your code produces this, but what it really should do is this:

Answer 1:

The shrinking font you are seeing is to be expected: as you add more characters horizontally, the font shrinks to fit everything in. There are a couple of solutions. The simplest is to break your text into smaller chunks and display them as separate rows, as in your desired output. You can also save your figure at a different DPI from the one used on screen, so that the letters look fine in the image file.
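The chunking idea can be sketched in plain Python, assuming a fixed number of characters per row. `chunk_text` is a hypothetical helper and is not part of Seaborn or the linked notebook. Each returned string would become one row of the heatmap, so the size of each cell no longer depends on the total length of the text:

```python
def chunk_text(text, n_limit):
    """Split text into rows of exactly n_limit characters.

    Hypothetical helper for illustration: the last row is padded with
    spaces so that every row has the same number of cells.
    """
    padded = text + ' ' * (-len(text) % n_limit)  # spaces needed to fill the last row
    return [padded[i:i + n_limit] for i in range(0, len(padded), n_limit)]


text = 'Hey I am feeling pretty boring today and the day is dull too'
for row in chunk_text(text, 20):
    print(repr(row))
```

With 20 characters per row, the 60-character sentence fits exactly into three rows. With 7 characters per row, it needs nine rows, and the last row is padded with three spaces.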

You should consider cleaning up your function along the way:

  1. n_limit is taken from count, which appears to be a global that is not defined in the code you posted.
  2. You overwrite variables without ever using their original values (e.g. num_chars, and the cell_height and cell_width parameters).
  3. You assign variables you never use (e.g. hmap).
  4. You recompute a lot of quantities multiple times.
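  5. The expression list(map(lambda x: x.replace('\n', '\\n'), text)) is total overkill: list(text.replace('\n', '\\n')) does nearly the same thing. The only difference is that each newline becomes two cells ('\\' and 'n') instead of one.
  6. Given that len(values) != len(text) in most cases, the line values = np.array((values+[0])*(total_chars-num_chars)) is nonsense. Multiplying a list repeats the whole list rather than padding it, so this line needs cleanup.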
  7. You are constructing numpy arrays by doing padding operations on lists, instead of using the power of numpy.
  8. You have the entire infrastructure for properly reshaping your arrays already in place, but you don't use it.
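Items 5 and 6 can be checked directly in plain Python. The short strings and lists here are made-up examples, not data from the question:

```python
text = 'line one\nline two'

# The original expression keeps one list element per character;
# the newline becomes the two-character string '\\n' in a single cell
per_char = list(map(lambda x: x.replace('\n', '\\n'), text))
# The shorter form splits each newline into two cells: '\\' and 'n'
per_cell = list(text.replace('\n', '\\n'))
print(len(text), len(per_char), len(per_cell))  # 17 17 18

# Item 6: multiplying a list repeats the whole list instead of padding it
values = [0.1, 0.2, 0.3]
print((values + [0]) * 2)  # [0.1, 0.2, 0.3, 0, 0.1, 0.2, 0.3, 0]

# Padding to a target length is what was actually intended
target = 5
print(values + [0] * (target - len(values)))  # [0.1, 0.2, 0.3, 0, 0]
```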

The updated version below fixes the minor issues and adds n_limit as a parameter, which determines how many characters you are willing to have in a row of the heat map. As I mentioned in the last item, you already have all the necessary code to reshape your arrays properly, and even to mask out the extra tail you sometimes end up with. The only thing that is wrong is the row length. Because n_limit is set to count, which apparently spans the whole text, the -1 in reshape(-1, n_limit) always resolves to a single row. Additionally, the figure is always saved at the same DPI (dpi=100 by default), so each cell comes out the same size in pixels for a given width, no matter how many rows you end up with. The DPI matters for PNG output because it increases or decreases the total number of pixels in the image, and on screen a PNG is typically shown pixel for pixel rather than at its nominal DPI:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def plot_neuron_heatmap(text, values, n_limit=80, savename='fig1.png',
                        cell_height=0.325, cell_width=0.15, dpi=100):
    # Show newlines as a literal "\n" (this adds one extra cell per newline)
    text = text.replace('\n', '\\n')
    num_chars = len(text)
    # Pad with spaces so the text fills a whole number of rows
    text = np.array(list(text + ' ' * (-num_chars % n_limit)))
    # Truncate or zero-pad the weights to one per cell. Use float: an integer
    # array would truncate every weight in (-1, 1) to 0
    padded = np.zeros(text.size, dtype=float)
    values = np.asarray(values, dtype=float)[:text.size]
    padded[:values.size] = values
    # Mask the padding cells after the end of the text
    mask = np.zeros(text.size, dtype=bool)
    mask[num_chars:] = True
    text = text.reshape(-1, n_limit)
    values = padded.reshape(-1, n_limit)
    mask = mask.reshape(-1, n_limit)
    plt.figure(figsize=(cell_width * n_limit, cell_height * len(text)))
    sns.heatmap(values, annot=text, mask=mask, fmt='', vmin=-5, vmax=5, cmap='RdYlGn',
                xticklabels=False, yticklabels=False, cbar=False)
    plt.savefig(savename if savename else 'fig1.png', dpi=dpi)
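The padding, reshaping and masking steps can be checked on their own, without Seaborn. This NumPy sketch uses the sample sentence. `char_grid` is a hypothetical name that mirrors the logic of the function above. It shows that `reshape(-1, n_limit)` yields a single row when `n_limit` equals the text length and more rows as `n_limit` shrinks, while the boolean mask flags only the padding cells:

```python
import numpy as np


def char_grid(text, n_limit):
    """Pad text to whole rows of n_limit cells and return (grid, mask).

    Hypothetical helper mirroring the padding/reshape logic above.
    """
    pad = -len(text) % n_limit                 # spaces needed to fill the last row
    cells = np.array(list(text + ' ' * pad))
    mask = np.zeros(cells.size, dtype=bool)
    mask[len(text):] = True                    # True marks padding cells to hide
    # -1 lets NumPy infer the number of rows from the total size and n_limit
    return cells.reshape(-1, n_limit), mask.reshape(-1, n_limit)


text = 'Hey I am feeling pretty boring today and the day is dull too'
for n_limit in (len(text), 20, 7):
    grid, mask = char_grid(text, n_limit)
    print(n_limit, grid.shape, int(mask.sum()))
```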

Here are a few sample runs of the function:

text = 'Hey I am feeling pretty boring today and the day is dull too'
values = [...] # The stuff in your question

plot_neuron_heatmap(text, values)
plot_neuron_heatmap(text, values, 20)
plot_neuron_heatmap(text, values, 7)

These calls produce the following three figures:
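To see why a fixed DPI keeps the letters a consistent size, you can work out the `figsize` arithmetic without plotting. This sketch assumes the default `cell_width`, `cell_height` and `dpi` of the updated function and the 60-character sample sentence. `figure_size` is a hypothetical helper, and the pixel sizes are approximate because the saved PNG may be rounded by a pixel:

```python
import math


def figure_size(num_chars, n_limit, cell_width=0.15, cell_height=0.325, dpi=100):
    """Return (rows, size in inches, approximate size in pixels) for one heatmap.

    Hypothetical helper: mirrors the figsize arithmetic of plot_neuron_heatmap
    with its default cell sizes and dpi; it does not draw anything.
    """
    rows = math.ceil(num_chars / n_limit)
    inches = (cell_width * n_limit, cell_height * rows)
    pixels = (inches[0] * dpi, inches[1] * dpi)
    return rows, inches, pixels


# The sample sentence has 60 characters; try the three n_limit values used above
for n_limit in (80, 20, 7):
    rows, inches, pixels = figure_size(60, n_limit)
    print(f'n_limit={n_limit}: {rows} row(s), '
          f'{inches[0]:.2f} x {inches[1]:.3f} in, ~{pixels[0]:.0f} x {pixels[1]:.1f} px')
```

The width depends only on `n_limit` and the height only on the number of rows. At dpi=100, each cell stays about 15 x 32.5 pixels whichever `n_limit` you choose.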