I would like to animate a network graph to show the progress of an algorithm. I am using NetworkX for graph creation.
From this SO answer, I came up with a solution using clear_output
from IPython.display
and the command plt.pause()
to manage the speed of the animation. This works well for small graphs with a few nodes, but when I run it on a 10x10 grid, the animation is very slow and reducing the argument in plt.pause()
does not seem to have any effect on the animation speed. Here is an MWE (minimal working example) with an implementation of Dijkstra's algorithm where I update the colors of the nodes at each iteration of the algorithm:
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline
# plotting function
def get_fig(G,current,pred):
    """Draw G on a new figure with each node colored by its role.

    Color priority (first match wins): ``current`` is red, ``pred`` is
    white, the module-level node ``N`` is grey, nodes whose
    ``node_visited`` flag is 1 are dodgerblue, everything else is
    powderblue.

    Reads the module-level ``pos``, ``N`` and ``node_visited``.
    """
    def _color_of(node):
        if node == current:
            return 'red'
        if node == pred:
            return 'white'
        if node == N:
            return 'grey'
        if node_visited[node] == 1:
            return 'dodgerblue'
        return 'powderblue'

    palette = [_color_of(node) for node in G.nodes()]
    # A fresh 10x10-inch figure is created on every call.
    plt.figure(figsize=(10,10))
    nx.draw_networkx(G,pos,node_color=palette,width=2,node_size=400,font_size=10)
    plt.axis('off')
    plt.show()
# graph creation
# 10x10 directed grid: node i is placed at column i % 10, row i // 10.
# Horizontally and vertically adjacent nodes are joined by two opposite
# arcs, each with its own random integer cost in [0, 9].
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    row, col = divmod(i, 10)
    pos[i] = (col, row)
    right, below = i + 1, i + 10
    # No arc to the right from the last column of a row.
    if col != 9 and right < 100:
        cost[(i, right)] = random.randint(0, 9)
        cost[(right, i)] = random.randint(0, 9)
    # No arc downwards from the last row.
    if below < 100:
        cost[(i, below)] = random.randint(0, 9)
        cost[(below, i)] = random.randint(0, 9)
# Iterating the cost dict yields its (tail, head) keys, i.e. the arcs.
G.add_edges_from(cost)
# algorithm initialization
# N is a randomly chosen node and the only initial entry of the priority
# queue SE, stored as a (label, node) tuple.
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
# lab[i]: 0 for N, 9999 for every other node.
lab = {node: (0 if node == N else 9999) for node in G.nodes()}
# path[i]: no successor recorded yet.
path = dict.fromkeys(G.nodes())
# node_visited[i]: 0 until the main loop marks the node with 1.
node_visited = dict.fromkeys(G.nodes(), 0)
# algorithm main loop
# Repeatedly pops the (label, node) entry with the smallest label from SE,
# marks that node visited and tries to lower the label of each of its
# predecessors i via lab[i] = cost[(i, j)] + lab[j].
# NOTE(review): the indentation of this script was lost. The redraw calls
# are placed inside the predecessor loop because they use its loop
# variable i; confirm against the original notebook.
while not SE.empty():
    j = SE.get()[1]
    # A node can be queued several times; skip entries for nodes that
    # were already processed.
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    for i in G.predecessors(j):
        candidate = cost[(i, j)] + lab[j]
        if lab[i] > candidate:
            lab[i] = candidate
            path[i] = j
            SE.put((lab[i], i))
        # wait=True keeps the previous output until the new one is ready.
        clear_output(wait=True)
        get_fig(G, j, i)
        plt.pause(0.0001)
print('end')
Ideally I would like to show the whole animation in no more than 5 seconds, whereas it currently takes a few minutes to complete the algorithm, which suggests that plt.pause(0.0001)
does not work as intended.
After reading SO posts on graph animation (post 2 and post 3), it seems that the animation
module from matplotlib could be used to resolve this but I have not been able to successfully implement the answers in my algorithm. The answer in post 2 suggests the use of FuncAnimation
from matplotlib but I am struggling to adapt the update
method to my problem and the answer in post 3 leads to a nice tutorial with a similar suggestion.
My question is how can I improve the speed of the animation for my problem: is it possible to arrange the clear_output
and plt.pause()
commands for faster animation or should I use FuncAnimation
from matplotlib? If it's the latter, then how should I define the update
function?
Thank you for your help.
EDIT 1
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
# plotting function
def get_fig(G,current,pred):
    """Recolor the stored node artists for the current search step.

    Each node's artist is read from its ``'draw'`` node attribute and
    recolored. Priority (first match wins): ``current`` is red, ``pred``
    is white, the module-level node ``N`` is grey, nodes whose
    ``node_visited`` flag is 1 are dodgerblue, everything else is
    powderblue.

    Reads the module-level ``N`` and ``node_visited``.
    """
    for i in G.nodes():
        if i == current:
            color = 'red'
        elif i == pred:
            color = 'white'
        elif i == N:
            color = 'grey'
        elif node_visited[i] == 1:
            color = 'dodgerblue'
        else:
            color = 'powderblue'
        # G.nodes[i] replaces the G.node[i] alias, which NetworkX 2.4
        # removed.
        G.nodes[i]['draw'].set_color(color)
# graph creation
# Nodes 0..99 laid out on a 10x10 grid (x = i % 10, y = i // 10). Each pair
# of grid neighbours gets one arc per direction with an independent random
# integer cost between 0 and 9.
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    y, x = divmod(i, 10)
    pos[i] = (x, y)
    east = i + 1
    south = i + 10
    # Nodes in the last column have no neighbour to the east.
    if x != 9 and east < 100:
        cost[(i, east)] = random.randint(0, 9)
        cost[(east, i)] = random.randint(0, 9)
    # Nodes in the last row have no neighbour to the south.
    if south < 100:
        cost[(i, south)] = random.randint(0, 9)
        cost[(south, i)] = random.randint(0, 9)
# The keys of cost are the (tail, head) pairs used as arcs.
G.add_edges_from(cost)
# algorithm initialization
# Sets up the search state and draws every node and arc once, storing each
# returned artist under the 'draw' attribute so it can be restyled later
# without redrawing the whole graph.
plt.figure(1, figsize=(10,10))
lab = {}
path = {}
node_visited = {}
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
for i in G.nodes():
    lab[i] = 0 if i == N else 9999
    path[i] = None
    node_visited[i] = 0
    # G.nodes[i] replaces the G.node[i] alias removed in NetworkX 2.4.
    # with_labels is dropped: it is a draw_networkx option, not a
    # draw_networkx_nodes parameter.
    G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=1,node_color='powderblue')
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(G,pos,edgelist=[(i,j)],width=2)
# Interactive mode so that showing the figure does not block the script.
plt.ion()
plt.draw()
plt.show()
# algorithm main loop
# Repeatedly pops the (label, node) entry with the smallest label from SE,
# marks that node visited and tries to lower the label of each of its
# predecessors i via lab[i] = cost[(i, j)] + lab[j].
# NOTE(review): the indentation of this script was lost. The recolor/redraw
# calls are placed inside the predecessor loop because they use its loop
# variable i; confirm against the original script.
while not SE.empty():
    j = SE.get()[1]
    # A node can be queued several times; skip entries for nodes that
    # were already processed.
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    for i in G.predecessors(j):
        candidate = cost[(i, j)] + lab[j]
        if lab[i] > candidate:
            lab[i] = candidate
            path[i] = j
            SE.put((lab[i], i))
        # Recolor the existing artists, then refresh the figure.
        get_fig(G, j, i)
        plt.draw()
        plt.pause(0.00001)
plt.close()
EDIT 2
import math
import queue
import random
import networkx as nx
import matplotlib.pyplot as plt
# graph creation
# Directed 10x10 grid on nodes 0..99. pos maps a node to its (column, row)
# cell; cost maps every arc (tail, head) to a random integer in [0, 9].
# Neighbouring cells are connected in both directions.
G = nx.DiGraph()
pos = {}
cost = {}
for i in range(100):
    row, column = divmod(i, 10)
    pos[i] = (column, row)
    # Horizontal pair, skipped for the last column of each row.
    if column != 9 and i + 1 < 100:
        for arc in ((i, i + 1), (i + 1, i)):
            cost[arc] = random.randint(0, 9)
    # Vertical pair, skipped for the last row.
    if i + 10 < 100:
        for arc in ((i, i + 10), (i + 10, i)):
            cost[arc] = random.randint(0, 9)
# Dict iteration yields the arc tuples themselves.
G.add_edges_from(cost)
# algorithm initialization
# Sets up the search state, draws every node and arc once (keeping each
# artist under the 'draw' attribute) and grabs the canvas used for blitting.
lab = {}
path = {}
node_visited = {}
N = random.randint(0,99)
SE = queue.PriorityQueue()
SE.put((0,N))
cf = plt.figure(1, figsize=(10,10))
# Axes spanning the whole figure.
ax = cf.add_axes((0,0,1,1))
for i in G.nodes():
    # G.nodes[i] replaces the G.node[i] alias removed in NetworkX 2.4.
    if i == N:
        lab[i] = 0
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=1.0,node_color='grey')
    else:
        lab[i] = 9999
        # Other nodes start faint (alpha=0.2).
        G.nodes[i]['draw'] = nx.draw_networkx_nodes(G,pos,nodelist=[i],node_size=400,alpha=0.2,node_color='dodgerblue')
    path[i] = None
    node_visited[i] = 0
for i, j in G.edges():
    G[i][j]['draw'] = nx.draw_networkx_edges(G,pos,edgelist=[(i,j)],width=3,alpha=0.2,arrows=False)
plt.ion()
plt.show()
ax = plt.gca()
canvas = ax.figure.canvas
# Snapshot of the initial rendering; not referenced again in the visible
# code.
background = canvas.copy_from_bbox(ax.bbox)
# algorithm main loop
# Repeatedly pops the (label, node) entry with the smallest label from SE,
# marks that node visited (drawn red unless it is N) and tries to lower the
# label of each of its predecessors i via lab[i] = cost[(i, j)] + lab[j].
# Only the touched artists are redrawn and blitted, not the whole figure.
# NOTE(review): the indentation of this script was lost. The alpha updates
# are nested under the label improvement and the redraw under the
# predecessor loop, since both use its loop variable i; confirm against the
# original script.
while not SE.empty():
    j = SE.get()[1]
    # A node can be queued several times; skip entries for nodes that
    # were already processed.
    if node_visited[j] == 1:
        continue
    node_visited[j] = 1
    # G.nodes[...] replaces the G.node[...] alias removed in NetworkX 2.4.
    if j != N:
        G.nodes[j]['draw'].set_color('r')
    for i in G.predecessors(j):
        candidate = cost[(i, j)] + lab[j]
        if lab[i] > candidate:
            lab[i] = candidate
            path[i] = j
            SE.put((lab[i], i))
            if i != N:
                G.nodes[i]['draw'].set_alpha(0.7)
                G[i][j]['draw'].set_alpha(1.0)
        ax.draw_artist(G[i][j]['draw'])
        ax.draw_artist(G.nodes[i]['draw'])
        ax.draw_artist(G.nodes[j]['draw'])
        canvas.blit(ax.bbox)
        plt.pause(0.0001)
plt.close()