Animating a network graph to show the progress of an algorithm

Posted 2019-05-07 16:19

Question:

I would like to animate a network graph to show the progress of an algorithm. I am using NetworkX for graph creation.

From this SO answer, I came up with a solution using clear_output from IPython.display and the command plt.pause() to manage the speed of the animation. This works well for small graphs with a few nodes, but when I run it on a 10x10 grid the animation is very slow, and reducing the argument of plt.pause() does not seem to have any effect on the animation speed. Here is an MWE (minimal working example) with an implementation of Dijkstra's algorithm where I update the colors of the nodes at each iteration of the algorithm:

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline

# plotting function
def get_fig(G,current,pred):
    """Draw the whole graph from scratch, colouring nodes by algorithm state.

    Colour priority (first match wins): ``current`` is red, ``pred`` is
    white, the source ``N`` is grey, visited nodes are dodgerblue and all
    remaining nodes are powderblue.  Reads the globals ``N``,
    ``node_visited`` and ``pos``.  A brand-new 10x10 figure is created on
    every call, so each animation frame is a complete redraw.
    """
    def colour_of(node):
        # Same precedence as an if/elif chain: earlier tests win.
        if node == current:
            return 'red'
        if node == pred:
            return 'white'
        if node == N:
            return 'grey'
        if node_visited[node] == 1:
            return 'dodgerblue'
        return 'powderblue'

    colours = [colour_of(node) for node in G.nodes()]
    plt.figure(figsize=(10, 10))
    nx.draw_networkx(G, pos, node_color=colours, width=2,
                     node_size=400, font_size=10)
    plt.axis('off')
    plt.show()

# graph creation: a 10x10 grid in which every pair of horizontally or
# vertically adjacent nodes is joined by two opposite directed edges, each
# with its own random integer cost in [0, 9].
G=nx.DiGraph()
pos={}
cost={}
for i in range(100):
    x = i % 10
    y = i // 10
    pos[i] = (x, y)
    if x != 9:                      # a right-hand neighbour exists
        cost[(i, i+1)] = random.randint(0, 9)
        cost[(i+1, i)] = random.randint(0, 9)
    if i + 10 < 100:                # a neighbour in the next row exists
        cost[(i, i+10)] = random.randint(0, 9)
        cost[(i+10, i)] = random.randint(0, 9)
G.add_edges_from(cost)   # iterating a dict yields its (u, v) keys

# algorithm initialization: Dijkstra rooted at the random node N, relaxing
# edges backwards (over predecessors), so lab[i] becomes the cheapest cost
# of a path from i to N and path[i] is the next hop on that path.
lab={}
path={}
node_visited={}
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
for i in G.nodes():
    lab[i] = 0 if i == N else 9999   # 9999 acts as "infinity"
    path[i] = None
    node_visited[i] = 0

# algorithm main loop: one full figure redraw per (predecessor, node) pair
# examined, which is what makes this version slow on larger grids.
while not SE.empty():
    _, j = SE.get()
    if node_visited[j]==1: continue   # stale entry for a finalized node
    node_visited[j] = 1
    for i in G.predecessors(j):
        if lab[i] > cost[(i,j)] + lab[j]:
            lab[i] = cost[(i,j)] + lab[j]
            path[i] = j
            SE.put((lab[i],i))
        clear_output(wait=True)
        get_fig(G,j,i)
        plt.pause(0.0001)
print('end')

Ideally I would like to show the whole animation in no more than 5 seconds, whereas it currently takes a few minutes to complete the algorithm, which suggests that plt.pause(0.0001) does not work as intended.

After reading SO posts on graph animation (post 2 and post 3), it seems that the animation module from matplotlib could be used to resolve this but I have not been able to successfully implement the answers in my algorithm. The answer in post 2 suggests the use of FuncAnimation from matplotlib but I am struggling to adapt the update method to my problem and the answer in post 3 leads to a nice tutorial with a similar suggestion.

My question is how can I improve the speed of the animation for my problem: is it possible to arrange the clear_output and plt.pause() commands for faster animation or should I use FuncAnimation from matplotlib? If it's the latter, then how should I define the update function?

Thank you for your help.

EDIT 1

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt

# plotting function
def get_fig(G,current,pred):
    """Recolour the existing per-node artists to show the algorithm state.

    Each node must already carry a ``'draw'`` attribute holding the
    collection returned by ``nx.draw_networkx_nodes``; nothing is drawn
    here, only colours change.  Priority (first match wins): ``current``
    red, ``pred`` white, the source ``N`` grey, visited nodes dodgerblue,
    all others powderblue.  Reads the globals ``N`` and ``node_visited``.
    """
    for i in G.nodes():
        if i == current: colour = 'red'
        elif i == pred: colour = 'white'
        elif i == N: colour = 'grey'
        elif node_visited[i] == 1: colour = 'dodgerblue'
        else: colour = 'powderblue'
        # ``G.node`` was removed in NetworkX 2.4; ``G.nodes[i]`` is the
        # node-attribute dict in NetworkX 2.0 and later.
        G.nodes[i]['draw'].set_color(colour)

# graph creation: a 10x10 grid in which every pair of horizontally or
# vertically adjacent nodes is joined by two opposite directed edges, each
# with its own random integer cost in [0, 9].
G=nx.DiGraph()
pos={}
cost={}
for i in range(100):
    x = i % 10
    y = i // 10
    pos[i] = (x, y)
    if x != 9:                      # a right-hand neighbour exists
        cost[(i, i+1)] = random.randint(0, 9)
        cost[(i+1, i)] = random.randint(0, 9)
    if i + 10 < 100:                # a neighbour in the next row exists
        cost[(i, i+10)] = random.randint(0, 9)
        cost[(i+10, i)] = random.randint(0, 9)
G.add_edges_from(cost)   # iterating a dict yields its (u, v) keys

# algorithm initialization: draw every node and edge once as its own
# artist so the main loop only has to recolour them.
plt.figure(1, figsize=(10,10))
lab={}
path={}
node_visited={}
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
for i in G.nodes():
    lab[i] = 0 if i == N else 9999   # 9999 acts as "infinity"
    path[i] = None
    node_visited[i] = 0
    # ``G.node`` was removed in NetworkX 2.4, and ``with_labels`` is not an
    # argument of draw_networkx_nodes (labels come from draw_networkx_labels).
    G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=1,node_color='powderblue')
for i,j in G.edges():
    G[i][j]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(i,j)],width=2)

plt.ion()
plt.draw()
plt.show()

# algorithm main loop: backward Dijkstra towards N (lab[i] = cheapest cost
# from i to N, path[i] = next hop), recolouring after every
# (predecessor, node) pair examined.
while not SE.empty():
    _, j = SE.get()
    if node_visited[j]==1: continue   # stale entry for a finalized node
    node_visited[j] = 1
    for i in G.predecessors(j):
        if lab[i] > cost[(i,j)] + lab[j]:
            lab[i] = cost[(i,j)] + lab[j]
            path[i] = j
            SE.put((lab[i],i))
        get_fig(G,j,i)
        plt.draw()
        plt.pause(0.00001)
plt.close()

EDIT 2

import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt

# graph creation: a 10x10 grid in which every pair of horizontally or
# vertically adjacent nodes is joined by two opposite directed edges, each
# with its own random integer cost in [0, 9].
G=nx.DiGraph()
pos={}
cost={}
for i in range(100):
    x = i % 10
    y = i // 10
    pos[i] = (x, y)
    if x != 9:                      # a right-hand neighbour exists
        cost[(i, i+1)] = random.randint(0, 9)
        cost[(i+1, i)] = random.randint(0, 9)
    if i + 10 < 100:                # a neighbour in the next row exists
        cost[(i, i+10)] = random.randint(0, 9)
        cost[(i+10, i)] = random.randint(0, 9)
G.add_edges_from(cost)   # iterating a dict yields its (u, v) keys

# algorithm initialization: one artist per node/edge (``G.node`` was
# removed in NetworkX 2.4, so the handles live in ``G.nodes[i]``).
lab={}
path={}
node_visited={}
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
cf = plt.figure(1, figsize=(10,10))
ax = cf.add_axes((0,0,1,1))
for i in G.nodes():
    path[i] = None
    node_visited[i] = 0
    if i == N:
        # the source is drawn opaque grey
        lab[i] = 0
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=1.0,node_color='grey')
    else:
        lab[i] = 9999   # acts as "infinity"
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=0.2,node_color='dodgerblue')
for i,j in G.edges():
    G[i][j]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(i,j)],width=3,alpha=0.2,arrows=False)

plt.ion()
plt.show()
canvas = cf.canvas
# Axes.draw_artist and canvas.blit only work once the figure has been
# rendered; force that initial full draw before the loop starts blitting.
canvas.draw()

# algorithm main loop: backward Dijkstra towards N (lab[i] = cheapest cost
# from i to N, path[i] = next hop), redrawing only the changed artists.
while not SE.empty():
    _, j = SE.get()
    if node_visited[j]==1: continue   # stale entry for a finalized node
    node_visited[j] = 1
    if j!=N:
        G.nodes[j]['draw'].set_color('r')
    for i in G.predecessors(j):
        if lab[i] > cost[(i,j)] + lab[j]:
            lab[i] = cost[(i,j)] + lab[j]
            path[i] = j
            SE.put((lab[i],i))
        if i!=N:
            G.nodes[i]['draw'].set_alpha(0.7)
            G[i][j]['draw'].set_alpha(1.0)
        # redraw just the touched edge/nodes, then blit the axes region
        ax.draw_artist(G[i][j]['draw'])
        ax.draw_artist(G.nodes[i]['draw'])
        ax.draw_artist(G.nodes[j]['draw'])
        canvas.blit(ax.bbox)
        plt.pause(0.0001)
plt.close()

Answer 1:

If your graph isn't too big you could try the following approach that sets the properties for individual nodes and edges. The trick is to save the output of the drawing functions which gives you a handle to the object properties like color, transparency, and visibility.

import networkx as nx
import matplotlib.pyplot as plt

G = nx.cycle_graph(12)
pos = nx.spring_layout(G)

cf = plt.figure(1, figsize=(8,8))
ax = cf.add_axes((0,0,1,1))

# Draw each node and edge as its own artist and keep the handle on the
# graph, so individual items can be restyled later without a full redraw.
# (``G.node`` was removed in NetworkX 2.4 -- use ``G.nodes`` -- and
# ``with_labels`` is not an argument of draw_networkx_nodes.)
for n in G:
    G.nodes[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n],node_size=200,alpha=0.5,node_color='r')
for u,v in G.edges():
    G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5)

plt.ion()
plt.draw()

sp = nx.shortest_path(G,0,6)
edges = zip(sp[:-1],sp[1:])

# Highlight the shortest path from node 0 to node 6, one edge per second.
for u,v in edges:
    plt.pause(1)
    G.nodes[u]['draw'].set_color('r')
    G.nodes[v]['draw'].set_color('r')
    G[u][v]['draw'].set_alpha(1.0)
    G[u][v]['draw'].set_color('r')
    plt.draw()

EDIT

Here is an example on a 10x10 grid using graphviz to do the layout. The whole thing runs in about 1 second on my machine.

import networkx as nx
import matplotlib.pyplot as plt

G = nx.grid_2d_graph(10,10)
# ``nx.graphviz_layout`` was removed from the top-level namespace in
# NetworkX 2.0; the PyGraphviz-backed layout now lives in ``nx.nx_agraph``.
pos = nx.nx_agraph.graphviz_layout(G)

cf = plt.figure(1, figsize=(8,8))
ax = cf.add_axes((0,0,1,1))

# One artist per node/edge, stored on the graph so each can be restyled
# (``G.nodes`` replaces the removed ``G.node``; ``with_labels`` is not an
# argument of draw_networkx_nodes).
for n in G:
    G.nodes[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n],node_size=200,alpha=0.5,node_color='k')
for u,v in G.edges():
    G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5)

plt.ion()
plt.draw()
plt.show()
sp = nx.shortest_path(G,(0,0),(9,9))
edges = zip(sp[:-1],sp[1:])

for u,v in edges:
    G.nodes[u]['draw'].set_color('r')
    G.nodes[v]['draw'].set_color('r')
    G[u][v]['draw'].set_alpha(1.0)
    G[u][v]['draw'].set_color('r')
    # plt.draw() only schedules a redraw; plt.pause() also runs the GUI
    # event loop briefly so each step actually appears on screen.
    plt.pause(0.001)

EDIT 2

Here is another approach that is faster (it doesn't redraw the axes or all the nodes) and uses a breadth-first search algorithm. This one runs in about 2 seconds on my machine. I noticed that some backends are faster - I'm using GTKAgg.

import networkx as nx
import matplotlib.pyplot as plt

def single_source_shortest_path(G,source):
    """Breadth-first search from ``source``, animating each discovery.

    Every node and edge of ``G`` must carry a ``'draw'`` attribute holding
    the artist returned by ``nx.draw_networkx_nodes`` /
    ``nx.draw_networkx_edges`` on the current axes.  Newly reached nodes
    turn red and the edge used to reach them turns black.  After one
    initial full draw, only those changed artists are redrawn and blitted.

    Returns a dict mapping each node reachable from ``source`` to the list
    of nodes on a fewest-edges path from ``source`` to it.
    """
    ax = plt.gca()
    canvas = ax.figure.canvas
    paths = {source: [source]}   # paths[w] = node list from source to w
    nextlevel = {source: 1}      # dict used as an insertion-ordered set
    src = G.nodes[source]['draw']   # ``G.node`` was removed in NetworkX 2.4
    src.set_color('r')
    src.set_alpha(1.0)   # alpha must be numeric, not the string '1.0'
    # Axes.draw_artist / canvas.blit need the figure rendered at least once.
    canvas.draw()
    while nextlevel:
        thislevel = nextlevel
        nextlevel = {}
        for v in thislevel:
            s = G.nodes[v]['draw']
            s.set_color('r')
            s.set_alpha(1.0)
            for w in G[v]:
                if w in paths:
                    continue   # already discovered at this or an earlier level
                n = G.nodes[w]['draw']
                n.set_color('r')
                n.set_alpha(1.0)
                e = G[v][w]['draw']
                e.set_alpha(1.0)
                e.set_color('k')
                ax.draw_artist(e)
                ax.draw_artist(n)
                ax.draw_artist(s)
                paths[w] = paths[v] + [w]
                nextlevel[w] = 1
                canvas.blit(ax.bbox)
                canvas.flush_events()   # let the GUI display the blitted frame
    return paths



if __name__=='__main__':

    G = nx.grid_2d_graph(10,10)
    # ``nx.graphviz_layout`` was removed from the top-level namespace in
    # NetworkX 2.0; the PyGraphviz-backed layout now lives in ``nx.nx_agraph``.
    pos = nx.nx_agraph.graphviz_layout(G)
    cf = plt.figure(1, figsize=(8,8))
    ax = cf.add_axes((0,0,1,1))

    # One artist per node/edge, stored on the graph so the search can
    # restyle them individually (``G.nodes`` replaces the removed
    # ``G.node``; ``with_labels`` is not a draw_networkx_nodes argument).
    for n in G:
        G.nodes[n]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[n],node_size=200,alpha=0.2,node_color='k')
    for u,v in G.edges():
        G[u][v]['draw']=nx.draw_networkx_edges(G,pos,edgelist=[(u,v)],alpha=0.5,arrows=False,width=5)
    plt.ion()
    plt.show()

    path = single_source_shortest_path(G,source=(0,0))