Spaces:
Build error
Build error
import plotly.graph_objects as go | |
import textwrap | |
import re | |
from collections import defaultdict | |
# Sentinel characters that bracket highlighted words between the highlighting
# step and the HTML-rendering step. Control characters are used (instead of
# e.g. "{{...}}") so they cannot collide with braces in the sentences, and
# they are not whitespace, so textwrap keeps them attached to their words.
_HL_OPEN = "\x02"
_HL_CLOSE = "\x03"
_HL_SPAN_RE = re.compile(f"{_HL_OPEN}(.*?){_HL_CLOSE}", re.DOTALL)


def _build_highlighter(color_map):
    """Compile one pattern matching every highlight word; return (pattern, colors).

    Each word is its own capturing group, so ``match.lastindex - 1`` indexes
    its color in the returned list. Words are tried longest first so a phrase
    wins over a shorter word it contains. Empty words are skipped because
    they would match at every position. Returns ``(None, [])`` when there is
    nothing to highlight.
    """
    entries = sorted(
        ((word, color) for word, color in color_map.items() if word),
        key=lambda item: len(item[0]),
        reverse=True,  # stable: equal-length words keep their given order
    )
    if not entries:
        return None, []
    alternatives = "|".join(f"({re.escape(word)})" for word, _ in entries)
    # Look-arounds instead of \b so words that begin or end with punctuation
    # (e.g. "C++", "U.S.") still match as whole words; for plain words this is
    # equivalent to \bword\b.
    pattern = re.compile(rf"(?<!\w)(?:{alternatives})(?!\w)", re.IGNORECASE)
    return pattern, [color for _, color in entries]


def _mark_highlights(sentence, pattern, colors):
    """Bracket each highlight match in *sentence* with the sentinel characters.

    Single pass, so markers never nest. The matched text is kept verbatim
    (original casing preserved). Returns ``(marked_sentence, match_colors)``,
    where ``match_colors`` lists the color of each match in order of
    appearance.
    """
    match_colors = []
    if pattern is None:
        return sentence, match_colors

    def _bracket(match):
        match_colors.append(colors[match.lastindex - 1])
        return f"{_HL_OPEN}{match.group(0)}{_HL_CLOSE}"

    return pattern.sub(_bracket, sentence), match_colors


def _render_highlights(marked_text, match_colors):
    """Turn sentinel-bracketed spans into Plotly ``<span style='color: ...'>`` markup.

    Spans are colored positionally from *match_colors*, so the coloring does
    not depend on how textwrap split the text. A span that was wrapped across
    ``<br>`` is colored piece by piece so each ``<span>`` stays on one line.
    """
    color_iter = iter(match_colors)

    def _colorize(match):
        color = next(color_iter, "black")
        return "<br>".join(
            f"<span style='color: {color};'>{piece}</span>" if piece else piece
            for piece in match.group(1).split("<br>")
        )

    return _HL_SPAN_RE.sub(_colorize, marked_text)


def _tree_edges(num_schemes, num_samples, samples_per_scheme):
    """Return (parent, child) node-index pairs for the root -> scheme -> sample tree.

    Node 0 is the root, nodes 1..num_schemes are scheme sentences, and the
    samples follow. Root edges come first, then scheme->sample edges grouped
    by scheme. Scheme ``s`` (0-based) owns samples
    ``s*k .. s*k + k - 1`` with ``k = samples_per_scheme``. Samples past
    ``num_schemes * k`` get no parent and are drawn unconnected.
    """
    edges = [(0, 1 + s) for s in range(num_schemes)]
    first_sample = 1 + num_schemes
    for s in range(num_schemes):
        start = s * samples_per_scheme
        stop = min(start + samples_per_scheme, num_samples)
        edges.extend((1 + s, first_sample + j) for j in range(start, stop))
    return edges


def _node_positions(level_sizes):
    """Map node index -> (x, y) plot coordinates, with levels laid out left to right.

    *level_sizes* gives the node count per level, in node-index order. Level
    ``L`` is drawn at ``x = 2 * L``. Its nodes are centred on ``y = 0`` with a
    vertical spacing of 10 for level 1 and 6 for every other level.
    """
    x_gap = 2
    positions = {}
    index = 0
    for level, size in enumerate(level_sizes):
        y_gap = 10 if level == 1 else 6
        for k in range(size):
            positions[index] = (level * x_gap, (k - (size - 1) / 2) * y_gap)
            index += 1
    return positions


def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence,
                     highlight_info, samples_per_scheme=4):
    """Build a Plotly tree: paraphrase -> scheme sentences -> sampled sentences.

    Level 0 is the single *paraphrased_sentence*, level 1 is
    *scheme_sentences* and level 2 is *sampled_sentence* (a sequence, despite
    the singular name). The root is linked to every scheme node. The ``s``-th
    scheme node is linked to the next *samples_per_scheme* sampled sentences,
    in order. Each sentence is wrapped at 30 characters, and every
    whole-word, case-insensitive occurrence of a highlight word is colored.

    Args:
        paraphrased_sentence: text of the root node.
        scheme_sentences: sequence of level-1 sentence strings.
        sampled_sentence: sequence of level-2 sentence strings.
        highlight_info: mapping, or iterable of ``(word, color)`` pairs (anything
            ``dict()`` accepts); ``color`` is inserted into a
            ``<span style='color: ...'>``. ``None``/empty disables highlighting.
        samples_per_scheme: sampled sentences attached to each scheme node
            (default 4, the previously hard-coded value).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    scheme_sentences = list(scheme_sentences)
    sampled_sentence = list(sampled_sentence)
    nodes = [paraphrased_sentence, *scheme_sentences, *sampled_sentence]

    # Compile the highlighter once for all nodes rather than once per word per node.
    pattern, colors = _build_highlighter(dict(highlight_info or {}))
    positions = _node_positions((1, len(scheme_sentences), len(sampled_sentence)))
    edges = _tree_edges(len(scheme_sentences), len(sampled_sentence), samples_per_scheme)

    fig = go.Figure()

    # All edges in one trace; None breaks the line between segments. Added
    # before the markers so the lines are drawn underneath them.
    if edges:
        edge_x, edge_y = [], []
        for parent, child in edges:
            (x0, y0), (x1, y1) = positions[parent], positions[child]
            edge_x.extend((x0, x1, None))
            edge_y.extend((y0, y1, None))
        fig.add_trace(go.Scatter(
            x=edge_x,
            y=edge_y,
            mode='lines',
            line=dict(color='black', width=1),
        ))

    fig.add_trace(go.Scatter(
        x=[positions[i][0] for i in range(len(nodes))],
        y=[positions[i][1] for i in range(len(nodes))],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='none',
    ))

    for i, sentence in enumerate(nodes):
        # Mark before wrapping so multi-word phrases are still found; the
        # sentinels survive wrapping and are converted to HTML afterwards.
        marked, match_colors = _mark_highlights(sentence, pattern, colors)
        wrapped = '<br>'.join(textwrap.wrap(marked, width=30))
        x, y = positions[i]
        fig.add_annotation(
            x=x,
            y=y,
            text=_render_highlights(wrapped, match_colors),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=8),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=150,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=1000,
    )
    return fig