LSPW / utils /plot.py
fanduluhf's picture
Upload 4 files
f460dc5 verified
import matplotlib.pyplot as plt
import numpy as np
from colorsys import hsv_to_rgb
import string
import networkx as nx
import re
# import osmnx as ox
import matplotlib.animation as animation
from itertools import combinations
import random
#################plot for transcripts############################
def plot_string(text, figsize=(18, 4)):
"""
Plot string with identical colors for identical letters and 0.2x letter width spacing.
Supports lowercase letters, uppercase letters, and underscores.
"""
plt.rcParams['font.family'] = 'Times New Roman'
plt.figure(figsize=figsize)
# Include underscore in the character set
unique_chars = (sorted(set(string.ascii_lowercase))[:10])
hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
for char, hue in zip(unique_chars, hues)}
unique_chars += ['-']
# Add special handling for underscore
color_map['-'] = (0.5, 0.5, 0.5) # Gray color for underscore
unique_chars += (sorted(set(string.ascii_lowercase))[10:])
color_map = {char: hsv_to_rgb(hue, 0.3, 0.8)
for char, hue in zip(unique_chars, hues)}
spacing = 0.1 # Space between letters relative to letter width
width = 1.0 # Width of each letter
total_width = len(text) * width * (1 + spacing) - spacing
for i, char in enumerate(text):
x_pos = i * width * (1 + spacing)
if char == '_':
# Draw underscore as a line slightly below the baseline
plt.plot([x_pos - width/3, x_pos + width/3],
[-0.2, -0.2],
color=color_map['_'],
linewidth=2)
else:
# Regular character plotting
color = color_map[char.lower()]
plt.text(x_pos, 0, char, fontsize=14, color=color,
ha='center', va='center')
plt.xlim(-width/2, total_width - width/2)
plt.ylim(-0.5, 0.5)
plt.axis('off')
plt.tight_layout()
plt.savefig('transcript.svg', format='svg')
#plt.show()
# Example usage
#text = "AACCCBBAAFFDDDEEEECCCCBBAAAEEECCCFFFFAAACCBBAAFFDDDDEEEEECCCCCBBBBAAAEECCCCCFFFAAAACCCBBBBAAAFFFBBBEECCBBBAAEECCCFFAA"
#plot_string(text)
def parse_sequence(sequence):
"""Parse sequence into main path and branches"""
branches = []
main_path = []
current_branch = []
in_branch = False
for part in sequence.split(','):
part = part.strip()
if '(' in part:
in_branch = True
part = part.replace('(', '')
if ')' in part:
in_branch = False
part = part.replace(')', '')
nodes = re.findall(r'([A-Za-z_]\d+)', part)
if len(nodes) == 2:
if in_branch:
current_branch.extend(nodes)
else:
if current_branch:
branches.append(current_branch)
current_branch = []
main_path.extend(nodes)
if current_branch:
branches.append(current_branch)
return main_path, branches
def get_node_number(node):
"""Extract number from node label"""
return int(re.findall(r'\d+', node)[0])
def plot_sequence(sequence, figsize=(8, 3)):
plt.figure(figsize=figsize)
# Assign colors
unique_chars = sorted(set(string.ascii_uppercase))[:10]
hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
for char, hue in zip(unique_chars, hues)}
unique_chars += ['_']
# Add special handling for underscore
color_map['_'] = (0.5, 0.5, 0.5) # Gray color for underscore
color_map.update({char: hsv_to_rgb(hue, 0.3, 0.8)
for char, hue in zip(sorted(set(string.ascii_uppercase))[10:], hues)})
main_path, branches = parse_sequence(sequence)
G = nx.DiGraph()
# Calculate positions
pos = {}
x_spacing = 1
y_spacing = 0.5
# Group nodes by their number
nodes_by_number = {}
all_nodes = set(main_path)
for branch in branches:
all_nodes.update(branch)
for node in all_nodes:
num = get_node_number(node)
if num not in nodes_by_number:
nodes_by_number[num] = []
nodes_by_number[num].append(node)
# Position nodes
for num in sorted(nodes_by_number.keys()):
nodes = nodes_by_number[num]
x = (num - 1) * x_spacing
if len(nodes) == 1:
pos[nodes[0]] = (x, 0)
else:
# Center branching nodes vertically
total_height = (len(nodes) - 1) * y_spacing
start_y = -total_height / 2
for i, node in enumerate(sorted(nodes)):
pos[node] = (x, start_y + i * y_spacing)
# Add edges
for i in range(0, len(main_path)-1, 2):
G.add_edge(main_path[i], main_path[i+1])
for branch in branches:
for i in range(0, len(branch)-1, 2):
G.add_edge(branch[i], branch[i+1])
# Draw arrows
nx.draw_networkx_edges(G, pos, edge_color='gray',
arrowsize=10, width=1.5,
arrowstyle='->')
# Draw nodes
for node in G.nodes():
letter = node[0]
color = color_map[letter]
circle = plt.Circle(pos[node], 0.1,
color=color, alpha=0.3)
plt.gca().add_patch(circle)
plt.text(pos[node][0], pos[node][1], node[0],
color=color, fontsize=8,
ha='center', va='center',
fontweight='bold')
plt.axis('equal')
plt.axis('off')
plt.tight_layout()
plt.savefig('transcript.svg', format='svg')
# Example usage
#sequence = "A1->C2, C2->B3, B3->A4, A4->F5, (F5->B6, F5->D6), (B6->E7, D6->E7), E7->C8, C8->B9, B9->A10, A10->E11, E11->C12, C12->F13"
#plot_sequence(sequence, (5.4,2))
#################plot for transcripts############################
#################plot for 2D trajectories############################
def plot_routes_animation(G, routes, colors, output_file, fps=20, duration_sec=10):
"""
Create an animation showing routes appearing dynamically.
Args:
G (networkx.MultiDiGraph): Street network graph
routes (list): List of routes (each route is a list of nodes)
colors (list): List of colors for each route
output_file (str): Output filename (should end with .gif or .mp4)
fps (int): Frames per second
duration_sec (int): Total animation duration in seconds
"""
# Create figure and axis
fig, ax = plt.subplots(figsize=(10, 8))
plt.rcParams['font.family'] = 'Times New Roman'
# Plot the base map
# ox.plot_graph(G, ax=ax, show=False, close=False,
# edge_color='gray', edge_alpha=0.2, node_size=0)
# Create empty route lines
route_lines = []
route_points = []
# Initialize all routes as empty
for color in colors:
line, = ax.plot([], [], marker='D', color=color, linewidth=2, alpha=0.8, zorder=2)
route_lines.append(line)
route_points.append([])
# Extract coordinates for all routes
all_route_coords = []
for route in routes:
coords = []
for node in route:
x = G.nodes[node]['x']
y = G.nodes[node]['y']
coords.append((x, y))
all_route_coords.append(coords)
# Calculate total number of frames
total_frames = fps * duration_sec
# Animation update function
def update(frame):
# Calculate progress (0 to 1)
progress = frame / total_frames
# Update each route
for i, coords in enumerate(all_route_coords):
# Determine how many points to show for this route
route_progress = min(1.0, progress * len(routes) - i)
if route_progress <= 0:
# Route hasn't started yet
route_lines[i].set_data([], [])
continue
# Calculate number of points to show
num_points = max(2, int(route_progress * len(coords)))
# Get coordinates to display
visible_coords = coords[:num_points]
xs, ys = zip(*visible_coords) if visible_coords else ([], [])
# Update line data
route_lines[i].set_data(xs, ys)
# Update legend based on which routes are visible
visible_routes = [i for i, line in enumerate(route_lines)
if len(line.get_xdata()) > 0]
if visible_routes:
# Update legend with only visible routes
ax.legend([route_lines[i] for i in visible_routes],
[f'Period {i+1}' for i in visible_routes],
loc='upper right', prop={'size': 14},
bbox_to_anchor=(1, 1))
return route_lines
# Create animation
ani = animation.FuncAnimation(
fig, update, frames=total_frames,
interval=1000/fps, blit=True
)
# Tight layout
plt.tight_layout()
# Save animation
if output_file.endswith('.gif'):
ani.save(output_file, writer='pillow', fps=fps, dpi=150)
else:
# For MP4, use ffmpeg
writer = animation.FFMpegWriter(fps=fps, bitrate=5000)
ani.save(output_file, writer=writer, dpi=150)
plt.close()
print(f"Animation saved to {output_file}")
# Example usage:
# plot_routes_animation(G, routes, colors, "route_animation.gif")
# For MP4: plot_routes_animation(G, routes, colors, "route_animation.mp4")
def plot_routes(G, routes, colors, output_file):
"""
Plot multiple routes on the same map.
Args:
G (networkx.MultiDiGraph): Street network graph
routes (list): List of routes (each route is a list of nodes)
colors (list): List of colors for each route
output_file (str): Output filename
"""
# Create figure and axis
fig, ax = plt.subplots(figsize=(8, 6))
plt.rcParams['font.family'] = 'Times New Roman'
# Plot the base map
# ox.plot_graph(G, ax=ax, show=False, close=False,
# edge_color='gray', edge_alpha=0.2, node_size=0)
# Create empty list to store route lines for legend
route_lines = []
# Plot each route
for route, color in zip(routes, colors):
# Extract the coordinates for each node in the route
xs = []
ys = []
for node in route:
# Get node coordinates
x = G.nodes[node]['x']
y = G.nodes[node]['y']
xs.append(x)
ys.append(y)
# Plot the route
line = ax.plot(xs, ys, marker='D', color=color, linewidth=2, alpha=0.2, zorder=2)[0]
route_lines.append(line)
# Add legend
ax.legend(route_lines,
[f'Period {i+1}' for i in range(len(routes))],
loc='upper right', prop={'size': 14},
bbox_to_anchor=(1.25, 0.85))
# Adjust layout and save
plt.tight_layout()
plt.savefig(output_file, dpi=100, bbox_inches='tight')
plt.show()
plt.close()
#################plot for 2D trajectories############################
def plot_task_2(obs_len, gt_seq_len, pred_seq_len, figsize_w=10, title=None):
"""
Plot both GT and Pred timelines in the same figure with aligned scales.
Args:
obs_len: Length of observation period
gt_seq_len: Total sequence length for ground truth
pred_seq_len: Total sequence length for prediction
figsize_w: Width of the figure
title: Optional title for the figure
"""
# Use the maximum sequence length to determine the x-axis limits
max_seq_len = max(gt_seq_len, pred_seq_len)
# Create figure with two subplots, one for GT and one for Pred
fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
plt.rcParams['font.family'] = 'Times New Roman'
# Add title if provided
if title:
fig.suptitle(title, fontsize=14, fontweight='bold')
# Create consistent bar heights and label offset
bar_height = 0.5
label_offset = max_seq_len * 0.15 # Proportional offset based on sequence length
# GT plot (top)
y_position = 0
axes[0].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
axes[0].barh(y_position, gt_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightgreen')
axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')
# Pred plot (bottom)
axes[1].barh(y_position, obs_len, height=bar_height, left=0, color='lightgray')
axes[1].barh(y_position, pred_seq_len+1-obs_len, height=bar_height, left=obs_len, color='lightblue')
axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')
# Configure both axes consistently
for ax in axes:
# Set consistent x-limits for alignment
ax.set_xlim(-label_offset, max_seq_len+1)
ax.set_ylim(-0.5, 0.5)
ax.set_yticks([])
# Remove the box/frame
for spine in ax.spines.values():
spine.set_visible(False)
# Add a thin line below the bar for better visibility
ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)
# Set tick marks for each plot
axes[0].set_xticks([0, obs_len, gt_seq_len])
axes[1].set_xticks([0, obs_len, pred_seq_len])
plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
return fig
def plot_task_3(gt_seq_len, GT_start, GT_end, pred_start, pred_end, figsize_w=10, title=None):
"""
Plot both GT and Pred timelines in the same figure with aligned scales.
Args:
gt_seq_len: Total sequence length for ground truth
GT_start: Start position of GT highlight bar
GT_end: End position of GT highlight bar
pred_start: Start position of prediction highlight bar
pred_end: End position of prediction highlight bar
figsize_w: Width of the figure
title: Optional title for the figure
"""
# Use the maximum sequence length to determine the x-axis limits
max_seq_len = gt_seq_len
# Create figure with two subplots, one for GT and one for Pred
fig, axes = plt.subplots(2, 1, figsize=(figsize_w, 2.5), gridspec_kw={'hspace': 0.3})
plt.rcParams['font.family'] = 'Times New Roman'
# Add title if provided
if title:
fig.suptitle(title, fontsize=14, fontweight='bold')
# Create consistent bar heights and label offset
bar_height = 0.5
label_offset = max_seq_len * 0.15 # Proportional offset based on sequence length
# GT plot (top)
y_position = 0
# Plot full lightgray bar for GT
axes[0].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
# Plot lightgreen bar within GT from GT_start to GT_end
axes[0].barh(y_position, GT_end - GT_start, height=bar_height, left=GT_start, color='lightgreen')
axes[0].text(-label_offset, y_position, "GT:", fontsize=12, fontweight='bold', verticalalignment='center')
# Pred plot (bottom)
# Plot full lightgray bar for Pred
axes[1].barh(y_position, gt_seq_len, height=bar_height, left=0, color='lightgray')
# Plot lightblue bar within Pred from pred_start to pred_end
axes[1].barh(y_position, pred_end - pred_start, height=bar_height, left=pred_start, color='lightblue')
axes[1].text(-label_offset, y_position, "Pred:", fontsize=12, fontweight='bold', verticalalignment='center')
# Configure both axes consistently
for ax in axes:
# Set consistent x-limits for alignment
ax.set_xlim(-label_offset, max_seq_len+1)
ax.set_ylim(-0.5, 0.5)
ax.set_yticks([])
# Remove the box/frame
for spine in ax.spines.values():
spine.set_visible(False)
# Add a thin line below the bar for better visibility
ax.axhline(y_position - bar_height/2, color='black', linewidth=0.5)
# Set tick marks for each plot
axes[0].set_xticks([0, GT_start, GT_end, gt_seq_len])
axes[1].set_xticks([0, pred_start, pred_end, gt_seq_len])
plt.tight_layout(rect=[0, 0, 1, 0.95] if title else [0, 0, 1, 1])
return fig
import string
def plot_images_with_token(images, tokens, n_rows = 2):
assert len(images) == len(tokens), "Each image must have a corresponding token"
n_images = len(images)
# Calculate rows and columns for grid layout
n_cols = (n_images + 1) // n_rows # Ceiling division to handle odd number of images
# Create a figure to display the images
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
plt.rcParams['font.family'] = 'Times New Roman'
unique_chars = (sorted(set(string.ascii_lowercase))[:10])
hues = np.linspace(0, 1, len(unique_chars), endpoint=False)
color_map = {char: hsv_to_rgb(hue, 0.8, 0.9)
for char, hue in zip(unique_chars, hues)}
# Make axes a 2D array even if there's just one column
if n_cols == 1:
axes = axes.reshape(-1, 1)
# Flatten axes for easy iteration if there are multiple columns
axes_flat = axes.flatten()
for i, (image, token) in enumerate(zip(images, tokens)):
if i < len(axes_flat):
color = color_map[token.lower()]
axes_flat[i].imshow(image)
axes_flat[i].set_title(token, color=color, size=50)
axes_flat[i].axis('off') # Hide axes
# Hide any unused subplots
for j in range(i + 1, n_rows * n_cols):
if j < len(axes_flat):
axes_flat[j].axis('off')
fig.delaxes(axes_flat[j])
plt.tight_layout()
plt.savefig('anchors.jpg', bbox_inches='tight', pad_inches=0)
#plt.show()