Matplotlib plot zooming with scroll wheel

2020-01-31 00:45发布

问题:

Is it possible to bind the scroll wheel to zoom in/out when the cursor is hovering over a matplotlib plot?

回答1:

This should work. It re-centers the graph on the location of the pointer when you scroll.

import matplotlib.pyplot as plt


def zoom_factory(ax, base_scale=2.):
    """Attach scroll-wheel zooming to *ax* and return the callback.

    Each wheel step rescales the visible range by ``base_scale`` and
    re-centres the view on the pointer position: wheel ``'up'`` shrinks the
    range (zoom in), wheel ``'down'`` enlarges it (zoom out).

    Parameters
    ----------
    ax : the axes to zoom; must provide ``get_xlim``/``get_ylim``,
        ``set_xlim``/``set_ylim``, ``get_figure`` and ``figure``.
    base_scale : factor applied to the visible range per wheel step.

    Returns
    -------
    The ``zoom_fun(event)`` callback that was connected to the figure's
    ``'scroll_event'``.
    """
    def zoom_fun(event):
        xdata = event.xdata  # pointer x in data coordinates
        ydata = event.ydata  # pointer y in data coordinates
        # The callback is registered on the whole figure, so it also fires
        # when the pointer is over the margins or another axes. In those
        # cases the data coordinates are missing or belong to another axes.
        if event.inaxes is not ax or xdata is None or ydata is None:
            return

        cur_xlim = ax.get_xlim()
        cur_ylim = ax.get_ylim()
        # half-extent of the current view along each axis
        cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
        cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5

        if event.button == 'up':
            # zoom in
            scale_factor = 1 / base_scale
        elif event.button == 'down':
            # zoom out
            scale_factor = base_scale
        else:
            # not expected for a scroll event; keep the scale and report it
            scale_factor = 1
            print(event.button)

        # new view: scaled half-extent on both sides of the pointer
        ax.set_xlim([xdata - cur_xrange * scale_factor,
                     xdata + cur_xrange * scale_factor])
        ax.set_ylim([ydata - cur_yrange * scale_factor,
                     ydata + cur_yrange * scale_factor])
        # redraw the figure that owns *ax*, whichever figure is "current"
        ax.figure.canvas.draw_idle()

    fig = ax.get_figure()  # the figure that owns the axes
    fig.canvas.mpl_connect('scroll_event', zoom_fun)

    return zoom_fun

Assuming you have an axis object ax

 # Plot some data, then enable wheel zooming with a 1.5x step.
 ax.plot(range(10))
 scale = 1.5
 # Keep the returned callback in a variable.
 f = zoom_factory(ax, base_scale=scale)

The optional argument base_scale lets you set the scale factor to whatever you want.

Make sure you keep a copy of f around. The callback is held through a weak reference, so if you do not keep a copy of f it might be garbage collected.

After writing this answer I decided this was actually quite useful and put it in a gist.



回答2:

Thanks guys, the examples were very helpful. I had to make a few changes to work with a scatter plot and I added panning with a left button drag. Hopefully someone will find this useful.

from matplotlib.pyplot import figure, show
import numpy

class ZoomPan:
    """Scroll-wheel zooming and click-and-drag panning for an axes.

    ``zoom_factory`` and ``pan_factory`` each connect callbacks to the
    figure that owns the given axes. The instance stores the drag state
    shared by the pan callbacks.
    """

    def __init__(self):
        # (x0, y0, xpress, ypress) while a drag is active, otherwise None
        self.press = None
        # axis limits tracked during a drag
        self.cur_xlim = None
        self.cur_ylim = None
        # never assigned a value other than None by this class
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None
        # data coordinates of the pointer when the button was pressed
        self.xpress = None
        self.ypress = None

    def zoom_factory(self, ax, base_scale=2.):
        """Connect a scroll-wheel zoom callback for *ax* and return it.

        Wheel ``'down'`` shrinks the visible range by ``1 / base_scale``
        (zoom in) and wheel ``'up'`` enlarges it by ``base_scale`` (zoom
        out). The data point under the pointer keeps its relative position
        inside the axes.
        """
        def zoom(event):
            xdata = event.xdata  # pointer x in data coordinates
            ydata = event.ydata  # pointer y in data coordinates
            # The callback is registered on the whole figure; ignore scrolls
            # over the margins or another axes, where the data coordinates
            # are missing or belong to another axes.
            if event.inaxes is not ax or xdata is None or ydata is None:
                return

            cur_xlim = ax.get_xlim()
            cur_ylim = ax.get_ylim()

            if event.button == 'down':
                # zoom in
                scale_factor = 1 / base_scale
            elif event.button == 'up':
                # zoom out
                scale_factor = base_scale
            else:
                # not expected for a scroll event; keep the scale, report it
                scale_factor = 1
                print(event.button)

            new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
            new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

            # fraction of the current range lying between the pointer and
            # the upper limit
            relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
            rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

            # split the new range around the pointer in the same proportion
            ax.set_xlim([xdata - new_width * (1 - relx),
                         xdata + new_width * relx])
            ax.set_ylim([ydata - new_height * (1 - rely),
                         ydata + new_height * rely])
            ax.figure.canvas.draw()

        fig = ax.get_figure()  # the figure that owns the axes
        fig.canvas.mpl_connect('scroll_event', zoom)

        return zoom

    def pan_factory(self, ax):
        """Connect press/release/motion callbacks that pan *ax* by dragging.

        Returns the motion callback.
        """
        def onPress(event):
            if event.inaxes is not ax:
                return
            self.cur_xlim = ax.get_xlim()
            self.cur_ylim = ax.get_ylim()
            self.xpress, self.ypress = event.xdata, event.ydata
            self.press = self.x0, self.y0, self.xpress, self.ypress

        def onRelease(event):
            self.press = None
            ax.figure.canvas.draw()

        def onMotion(event):
            if self.press is None:
                return
            if event.inaxes is not ax:
                return
            dx = event.xdata - self.xpress
            dy = event.ydata - self.ypress
            # get_xlim()/get_ylim() may return tuples, which do not support
            # in-place subtraction of a scalar, so build the shifted limits
            # explicitly.
            self.cur_xlim = (self.cur_xlim[0] - dx, self.cur_xlim[1] - dx)
            self.cur_ylim = (self.cur_ylim[0] - dy, self.cur_ylim[1] - dy)
            ax.set_xlim(self.cur_xlim)
            ax.set_ylim(self.cur_ylim)

            ax.figure.canvas.draw()

        fig = ax.get_figure()  # the figure that owns the axes

        # attach the callbacks
        fig.canvas.mpl_connect('button_press_event', onPress)
        fig.canvas.mpl_connect('button_release_event', onRelease)
        fig.canvas.mpl_connect('motion_notify_event', onMotion)

        return onMotion


fig = figure()

# Fixed unit-square view; autoscaling is off.
ax = fig.add_subplot(
    111,
    xlim=(0, 1),
    ylim=(0, 1),
    autoscale_on=False,
)
ax.set_title('Click to zoom')

# One random array of four rows, unpacked as x, y, marker size and colour.
x, y, s, c = numpy.random.rand(4, 200)
s *= 200

ax.scatter(x, y, s, c)

# Wire up zooming and panning.
scale = 1.1
zp = ZoomPan()
figZoom = zp.zoom_factory(ax, base_scale=scale)
figPan = zp.pan_factory(ax)
show()


回答3:

def zoom(self, event, factor):
    """Rescale ``self.ax`` by *factor* around the pointer position.

    The new x and y ranges are the current ranges multiplied by *factor*.
    They are placed so that the data point under the pointer keeps its
    relative position inside the axes.

    Parameters
    ----------
    event : mouse event providing ``xdata`` and ``ydata``.
    factor : multiplier for the visible range; values below 1 zoom in.
    """
    # No data coordinates (pointer outside the axes): nothing to anchor on.
    if event.xdata is None or event.ydata is None:
        return

    curr_xlim = self.ax.get_xlim()
    curr_ylim = self.ax.get_ylim()

    # Each new extent is derived from the limits of its own axis.
    new_width = (curr_xlim[1] - curr_xlim[0]) * factor
    new_height = (curr_ylim[1] - curr_ylim[0]) * factor

    # fraction of the current range between the pointer and the upper limit
    relx = (curr_xlim[1] - event.xdata) / (curr_xlim[1] - curr_xlim[0])
    rely = (curr_ylim[1] - event.ydata) / (curr_ylim[1] - curr_ylim[0])

    self.ax.set_xlim([event.xdata - new_width * (1 - relx),
                      event.xdata + new_width * relx])
    self.ax.set_ylim([event.ydata - new_height * (1 - rely),
                      event.ydata + new_height * rely])
    self.draw()

The purpose of this slightly altered code is to keep track of the position of the cursor relative to the new zoom center. This way, if you zoom in and out of the picture at points other than the center you remain on the same point.



回答4:

Thanks very much. This worked great. However, for plots where the scale is no longer linear (log plots for instance), this breaks down. I have written a new version for this. I hope it helps someone.

Basically, I zoom in the axes coordinates which are normalized to be [0,1]. So, if I zoom in by two in x, I want to now be in the [.25, .75] range. I have also added a feature to only zoom in x if you're directly above or below the x axis, and only in y if you're directly left or right to the y axis. If you don't need this, just set zoomx=True, and zoomy = True and ignore the if statements.

This reference is very useful for those who want to understand how matplotlib transforms between different coordinate systems: http://matplotlib.org/users/transforms_tutorial.html

This function is within an object that contains a pointer to the axes (self.ax).

def zoom(self,event):
    '''Zoom ``self.ax`` in response to a mouse-wheel scroll event.

    The pointer position decides which axes are rescaled:
      * inside the plot area: both x and y;
      * directly above or below the plot area: x only;
      * directly left or right of the plot area: y only;
      * in a corner region: neither.
    The new limits are built in axes coordinates (the plot area spans
    [0, 1] in each direction) and then mapped back to data limits through
    the axes' limit and scale transforms. The zoom is centred on the
    middle of the axes, not on the pointer.
    NOTE: the original author points out that a figure with several
    subplots needs an extra check that the pointer is not over another
    plot; no such check is made here.'''

    pixel_pos = (event.x, event.y)

    # pixels -> axes coordinates
    to_axes = self.ax.transAxes.inverted().transform
    # axes coordinates -> limits
    to_limits = self.ax.transLimits.inverted().transform
    # undo the axis scale (matters for log plots)
    unscale = self.ax.transScale.inverted().transform

    wheel = event.button
    if wheel == 'down':
        scale_factor = self.zoom_scale
    elif wheel == 'up':
        scale_factor = 1 / self.zoom_scale
    else:
        # unexpected button value: leave the view unchanged
        scale_factor = 1

    # where the pointer sits relative to the plot area
    xa, ya = to_axes(pixel_pos)
    zoomx = 0 <= xa <= 1
    zoomy = 0 <= ya <= 1

    # axes-coordinate interval of width scale_factor centred on 0.5
    scaled_span = ((1 - scale_factor) * .5, (1 + scale_factor) * .5)
    new_alimx = scaled_span if zoomx else (0, 1)
    new_alimy = scaled_span if zoomy else (0, 1)

    # map the lower-left and upper-right corners back to data limits
    new_xlim0, new_ylim0 = unscale(to_limits((new_alimx[0], new_alimy[0])))
    new_xlim1, new_ylim1 = unscale(to_limits((new_alimx[1], new_alimy[1])))

    self.ax.set_xlim([new_xlim0, new_xlim1])
    self.ax.set_ylim([new_ylim0, new_ylim1])
    self.redraw()


回答5:

I really like the "x only" or "y only" modes in the figure plots. You can bind the x and y keys so that zooming only happens in one direction. Note that you may also have to place the focus back on the canvas if you click on an Entry box or another widget:

def _focus_canvas(event):
    # Give keyboard focus to the canvas' Tk widget.
    return canvas._tkcanvas.focus_set()


canvas.mpl_connect('button_press_event', _focus_canvas)

The rest of the modified code is below:

from matplotlib.pyplot import figure, show
import numpy

class ZoomPan:
    """Scroll-wheel zooming (optionally x-only or y-only) and drag panning.

    While the ``x`` key is held, zooming changes only the x limits; while
    the ``y`` key is held, only the y limits. Releasing any key restores
    zooming on both axes. The connection ids returned by ``mpl_connect``
    are stored on the instance.
    """

    def __init__(self):
        # (x0, y0, xpress, ypress) while a drag is active, otherwise None
        self.press = None
        # axis limits tracked during a drag
        self.cur_xlim = None
        self.cur_ylim = None
        # never assigned a value other than None by this class
        self.x0 = None
        self.y0 = None
        self.x1 = None
        self.y1 = None
        # data coordinates of the pointer when the button was pressed
        self.xpress = None
        self.ypress = None
        # which axes the scroll wheel is currently allowed to rescale
        self.xzoom = True
        self.yzoom = True
        # connection ids returned by mpl_connect
        self.cidBP = None
        self.cidBR = None
        self.cidBM = None
        self.cidKeyP = None
        self.cidKeyR = None
        self.cidScroll = None

    def zoom_factory(self, ax, base_scale=2.):
        """Connect scroll and key callbacks for *ax*; return the zoom callback.

        Wheel ``'down'`` shrinks the visible range by ``1 / base_scale``
        (zoom in) and wheel ``'up'`` enlarges it by ``base_scale`` (zoom
        out). The data point under the pointer keeps its relative position
        inside the axes.
        """
        def zoom(event):
            xdata = event.xdata  # pointer x in data coordinates
            ydata = event.ydata  # pointer y in data coordinates
            # The callback is registered on the whole figure; ignore scrolls
            # over the margins or another axes, where the data coordinates
            # are missing or belong to another axes.
            if event.inaxes is not ax or xdata is None or ydata is None:
                return

            cur_xlim = ax.get_xlim()
            cur_ylim = ax.get_ylim()

            if event.button == 'down':
                # zoom in
                scale_factor = 1 / base_scale
            elif event.button == 'up':
                # zoom out
                scale_factor = base_scale
            else:
                # not expected for a scroll event; keep the scale, report it
                scale_factor = 1
                print(event.button)

            new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
            new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor

            # fraction of the current range lying between the pointer and
            # the upper limit
            relx = (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])
            rely = (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])

            if self.xzoom:
                ax.set_xlim([xdata - new_width * (1 - relx),
                             xdata + new_width * relx])
            if self.yzoom:
                ax.set_ylim([ydata - new_height * (1 - rely),
                             ydata + new_height * rely])
            ax.figure.canvas.draw()
            ax.figure.canvas.flush_events()

        def onKeyPress(event):
            if event.key == 'x':
                self.xzoom = True
                self.yzoom = False
            if event.key == 'y':
                self.xzoom = False
                self.yzoom = True

        def onKeyRelease(event):
            self.xzoom = True
            self.yzoom = True

        fig = ax.get_figure()  # the figure that owns the axes

        self.cidScroll = fig.canvas.mpl_connect('scroll_event', zoom)
        self.cidKeyP = fig.canvas.mpl_connect('key_press_event', onKeyPress)
        self.cidKeyR = fig.canvas.mpl_connect('key_release_event', onKeyRelease)

        return zoom

    def pan_factory(self, ax):
        """Connect press/release/motion callbacks that pan *ax* by dragging.

        Returns the motion callback.
        """
        def onPress(event):
            if event.inaxes is not ax:
                return
            self.cur_xlim = ax.get_xlim()
            self.cur_ylim = ax.get_ylim()
            self.xpress, self.ypress = event.xdata, event.ydata
            self.press = self.x0, self.y0, self.xpress, self.ypress

        def onRelease(event):
            self.press = None
            ax.figure.canvas.draw()

        def onMotion(event):
            if self.press is None:
                return
            if event.inaxes is not ax:
                return
            dx = event.xdata - self.xpress
            dy = event.ydata - self.ypress
            # get_xlim()/get_ylim() may return tuples, which do not support
            # in-place subtraction of a scalar, so build the shifted limits
            # explicitly.
            self.cur_xlim = (self.cur_xlim[0] - dx, self.cur_xlim[1] - dx)
            self.cur_ylim = (self.cur_ylim[0] - dy, self.cur_ylim[1] - dy)
            ax.set_xlim(self.cur_xlim)
            ax.set_ylim(self.cur_ylim)

            ax.figure.canvas.draw()
            ax.figure.canvas.flush_events()

        fig = ax.get_figure()  # the figure that owns the axes

        # attach the callbacks
        self.cidBP = fig.canvas.mpl_connect('button_press_event', onPress)
        self.cidBR = fig.canvas.mpl_connect('button_release_event', onRelease)
        self.cidBM = fig.canvas.mpl_connect('motion_notify_event', onMotion)

        return onMotion


回答6:

This is a suggestion for a slight modification to the code above - it makes keeping the zoom centred more manageable.

    # pointer position in data coordinates
    xmouse, ymouse = event.xdata, event.ydata

    # half-extent and midpoint of the current view along each axis
    cur_xrange = .5*(cur_xlim[1] - cur_xlim[0])
    cur_xcentre = .5*(cur_xlim[0] + cur_xlim[1])
    cur_yrange = .5*(cur_ylim[1] - cur_ylim[0])
    cur_ycentre = .5*(cur_ylim[0] + cur_ylim[1])

    # zoom centre: a quarter of the way from the view centre to the pointer
    xdata = cur_xcentre + 0.25*(xmouse - cur_xcentre)
    ydata = cur_ycentre + 0.25*(ymouse - cur_ycentre)


回答7:

This makes tacaswell's answer 'smooth': the zoom centre is kept while the mouse stays (almost) still between scroll steps.

def zoom_factory(ax, base_scale=2.):
    """Attach "smooth" scroll-wheel zooming to *ax* and return the callback.

    Wheel ``'up'`` shrinks the visible range by ``1 / base_scale`` (zoom in)
    and wheel ``'down'`` enlarges it by ``base_scale`` (zoom out). The view
    is centred on the pointer. While the pointer stays within 10 pixels (in
    both x and y) of where the current centre was chosen, that centre is
    reused, so repeated wheel steps do not make the view drift.

    Parameters
    ----------
    ax : the axes to zoom.
    base_scale : factor applied to the visible range per wheel step.

    Returns
    -------
    The ``zoom_fun(event)`` callback connected to ``'scroll_event'``.
    """
    # Pixel position and data position of the last chosen zoom centre.
    # None means "no centre chosen yet", so the first scroll always picks
    # the pointer position, even when it happens near pixel (0, 0).
    prex = None
    prey = None
    prexdata = None
    preydata = None

    def zoom_fun(event):
        nonlocal prex, prey, prexdata, preydata
        # The callback is registered on the whole figure; ignore scrolls
        # over the margins or another axes, where the data coordinates are
        # missing or belong to another axes.
        if event.inaxes is not ax or event.xdata is None or event.ydata is None:
            return

        curx = event.x
        cury = event.y

        pointer_still = (
            prex is not None
            and abs(curx - prex) < 10
            and abs(cury - prey) < 10
        )
        if pointer_still:
            # pointer (almost) unmoved: keep the previous zoom centre
            xdata = prexdata
            ydata = preydata
        else:
            # pointer moved: the new centre is the pointer position
            xdata = event.xdata
            ydata = event.ydata

            # remember where this centre was chosen
            prex = curx
            prey = cury
            prexdata = xdata
            preydata = ydata

        cur_xlim = ax.get_xlim()
        cur_ylim = ax.get_ylim()

        # half-extent of the current view along each axis
        cur_xrange = (cur_xlim[1] - cur_xlim[0]) * .5
        cur_yrange = (cur_ylim[1] - cur_ylim[0]) * .5

        if event.button == 'up':
            # zoom in
            scale_factor = 1 / base_scale
        elif event.button == 'down':
            # zoom out
            scale_factor = base_scale
        else:
            # not expected for a scroll event; keep the scale and report it
            scale_factor = 1
            print(event.button)

        ax.set_xlim([
            xdata - cur_xrange * scale_factor,
            xdata + cur_xrange * scale_factor
        ])
        ax.set_ylim([
            ydata - cur_yrange * scale_factor,
            ydata + cur_yrange * scale_factor
        ])
        # redraw the figure that owns *ax*, whichever figure is "current"
        ax.figure.canvas.draw_idle()

    fig = ax.get_figure()  # the figure that owns the axes
    fig.canvas.mpl_connect('scroll_event', zoom_fun)

    return zoom_fun


回答8:

The other answers using ax.set_xlim() and ax.set_ylim() did not give a satisfactory user experience for figures where setting the axes is slow. (for me this was an axes with a pcolormesh) The method ax.drag_pan() is much faster, and I believe it is more suited for most cases:

def mousewheel_move(event):
    """Zoom the axes under the pointer by simulating a right-button drag.

    A pan/zoom gesture is started at the pointer position, dragged 10
    pixels diagonally with button 3, and ended. Wheel ``'up'`` drags
    towards +x/+y; any other value drags towards -x/-y. ``event.key`` is
    passed through to ``drag_pan``.
    """
    ax = event.inaxes
    # The pointer is not over any axes, so there is nothing to zoom.
    if ax is None:
        return

    # start_pan records the view state that drag_pan works from, so the
    # handler does not have to write the private ``_pan_start`` attribute.
    ax.start_pan(event.x, event.y, 3)
    try:
        if event.button == 'up':
            ax.drag_pan(3, event.key, event.x + 10, event.y + 10)
        else:  # event.button == 'down'
            ax.drag_pan(3, event.key, event.x - 10, event.y - 10)
    finally:
        # always close the gesture, even if drag_pan raises
        ax.end_pan()

    fig = ax.get_figure()
    fig.canvas.draw_idle()

Then connect your figure with:

# Register the handler for mouse-wheel events on this figure's canvas.
fig.canvas.mpl_connect('scroll_event', mousewheel_move)

Tested with matplotlib 3.0.2 using the TkAgg backend and Python 3.6.