# Author: Maharshi Gor
# Commit 193db9d: "First Working commit" (6.75 kB)
import json
import logging
import re
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd
def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int:
"""Evaluate the buzz of a prediction against the clean answers."""
if isinstance(clean_answers, str):
print("clean_answers is a string")
clean_answers = [clean_answers]
pred = prediction.lower().strip()
if not pred:
return 0
for answer in clean_answers:
answer = answer.strip().lower()
if answer and answer in pred:
print(f"Found {answer} in {pred}")
return 1
return 0
def create_answer_html(answer: str):
    """Wrap ``answer`` in the ``answer-header`` div shown above the tokens.

    The answer is inserted as-is (no HTML escaping), so any markup it
    contains is rendered by the browser.
    """
    template = "<div class='answer-header'>Answer:<br>{}</div>"
    return template.format(answer)
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None):
"""Create HTML for tokens with hover capability and a colored header for the answer."""
try:
html_parts = []
ep = dict(eval_points)
marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set()
# Add a colored header for the answer
# html_parts.append(create_answer_html(answer))
for i, token in enumerate(tokens):
# Check if this token is a buzz point
values = ep.get(i, (None, 0, 0))
confidence, buzz_point, score = values
# Replace non-word characters for proper display in HTML
display_token = token
if not re.match(r"\w+", token):
display_token = token.replace(" ", "&nbsp;")
# Add buzz marker class if it's a buzz point
if confidence is None:
css_class = ""
elif not buzz_point:
css_class = " guess-point no-buzz"
else:
css_class = f" guess-point buzz-{score}"
token_html = f'<span id="token-{i}" class="token{css_class}" data-index="{i}">{display_token}</span>'
if i in marker_indices:
token_html += "<span style='color: rgba(0,0,255,0.3);'>|</span>"
html_parts.append(token_html)
return f"<div class='token-container'>{''.join(html_parts)}</div>"
except Exception as e:
logging.error(f"Error creating token HTML: {e}", exc_info=True)
return f"<div class='token-container'>Error creating tokens: {str(e)}</div>"
def create_line_plot(eval_points, highlighted_index=-1):
    """Build a DataFrame of buzz points for a Gradio ``LinePlot``.

    Args:
        eval_points: Pairs ``(token_index, (confidence, buzz_point, ...))``.
            Both ``(confidence, buzz_point)`` 2-tuples and the
            ``(confidence, buzz_point, score)`` 3-tuples used elsewhere in
            this module are accepted; extra trailing fields are ignored.
        highlighted_index: Token index to mark with a vertical "hover-line"
            spanning values 0 to 1. Negative values or ``None`` disable it.

    Returns:
        DataFrame with columns ``position``, ``value``, ``type``,
        ``highlight``, ``color`` — the same columns on success, when empty,
        and on error.
    """
    columns = ["position", "value", "type", "highlight", "color"]
    try:
        data = []
        for i, (v, b, *_) in eval_points:
            # Red marks "did not buzz", green marks "buzzed".
            color = "#ff4444" if b == 0 else "#228b22"
            data.append(
                {
                    "position": i,
                    "value": v,
                    "type": "buzz",
                    "highlight": True,
                    "color": color,
                }
            )
        if highlighted_index is not None and highlighted_index >= 0:
            # Two points at the same x form a vertical line from 0 to 1.
            data.extend(
                [
                    {
                        "position": highlighted_index,
                        "value": 0,
                        "type": "hover-line",
                        "color": "#000000",
                        "highlight": True,
                    },
                    {
                        "position": highlighted_index,
                        "value": 1,
                        "type": "hover-line",
                        "color": "#000000",
                        "highlight": True,
                    },
                ]
            )
        # Explicit columns keep the schema stable even when ``data`` is empty.
        return pd.DataFrame(data, columns=columns)
    except Exception as e:
        logging.error(f"Error creating line plot: {e}", exc_info=True)
        return pd.DataFrame(columns=columns)
def create_pyplot(tokens, eval_points, highlighted_index=-1):
    """Plot buzz confidence per token as a matplotlib figure.

    The confidence curve starts at (0, 0) and token ``i`` is drawn at
    x = ``i + 1``. Points where the model buzzed are re-drawn green (score
    truthy) or red (score falsy) and annotated with their token text.

    Args:
        tokens: Token strings, used to label buzz points.
        eval_points: Pairs ``(token_index, (confidence, buzz_point, score))``.
        highlighted_index: Token index to mark with a dashed vertical line.
            Negative values or ``None`` disable it.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    plt.style.use("ggplot")  # Note: sets the style globally, not per-figure.
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    x = [0]
    y = [0]
    for i, (v, b, s) in eval_points:
        x.append(i + 1)
        y.append(v)
    ax.plot(x, y, "o--", color="#4698cf")

    for i, (v, b, s) in eval_points:
        if not b:
            continue
        color = "green" if s else "red"
        ax.plot(i + 1, v, "o", color=color)
        # Only annotate when the index points at a real token. Previously the
        # out-of-range case was reported and then indexed anyway (IndexError).
        if 0 <= i < len(tokens):
            ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")
        else:
            logging.warning("Token index %s is out of bounds for n_tokens: %s", i, len(tokens))

    if highlighted_index is not None and highlighted_index >= 0:
        # ymin/ymax are axes fractions: the line spans the full plot height.
        ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)

    ax.set_title("Buzz Confidence")
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Confidence")
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    return fig
def create_scatter_pyplot(token_positions, scores):
    """Scatter (token position, score) pairs, sizing markers by frequency.

    Each distinct pair from ``zip(token_positions, scores)`` is drawn once,
    with marker area ``20 * occurrences``.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    pair_counts = Counter(zip(token_positions, scores))
    # Counter iteration order is the same for keys and values, so the three
    # lists stay aligned.
    xs = [pos for pos, _ in pair_counts]
    ys = [score for _, score in pair_counts]
    sizes = [count * 20 for count in pair_counts.values()]
    ax.scatter(xs, ys, color="#4698cf", s=sizes)
    return fig
def update_plot(highlighted_index, state):
    """Redraw the confidence plot when a token is hovered.

    Args:
        highlighted_index: Index of the hovered token (int or numeric
            string). ``None`` or ``""`` means no token is highlighted.
        state: JSON string (or an already-decoded dict) with keys
            ``"tokens"`` and ``"values"``; ``values`` is forwarded to
            ``create_pyplot`` as its eval points.

    Returns:
        The figure produced by ``create_pyplot``, or an empty
        ``pd.DataFrame`` when the state is empty/incomplete or an error occurs.
    """
    try:
        if not state or state == "{}":
            logging.warning("Empty state provided to update_plot")
            return pd.DataFrame()
        # Map "no hover" to -1, create_pyplot's "no highlight" sentinel. The
        # previous truthiness test turned a legitimate index 0 into None, and
        # None then failed the ``>= 0`` comparison downstream, so the plot was
        # replaced by an empty DataFrame.
        if highlighted_index is None or highlighted_index == "":
            highlighted_index = -1
        else:
            highlighted_index = int(highlighted_index)
        logging.info(f"Update plot triggered with token index: {highlighted_index}")
        data = state if isinstance(state, dict) else json.loads(state)
        tokens = data.get("tokens", [])
        values = data.get("values", [])
        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return pd.DataFrame()
        return create_pyplot(tokens, values, highlighted_index)
    except Exception as e:
        logging.error(f"Error updating plot: {e}", exc_info=True)
        return pd.DataFrame()