|
import json |
|
import logging |
|
import re |
|
from collections import Counter |
|
|
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
|
|
|
|
def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int: |
|
"""Evaluate the buzz of a prediction against the clean answers.""" |
|
if isinstance(clean_answers, str): |
|
print("clean_answers is a string") |
|
clean_answers = [clean_answers] |
|
pred = prediction.lower().strip() |
|
if not pred: |
|
return 0 |
|
for answer in clean_answers: |
|
answer = answer.strip().lower() |
|
if answer and answer in pred: |
|
print(f"Found {answer} in {pred}") |
|
return 1 |
|
return 0 |
|
|
|
|
|
def create_answer_html(answer: str):
    """Wrap ``answer`` in the ``answer-header`` div, under an "Answer:" label."""
    label = "Answer:<br>"
    return f"<div class='answer-header'>{label}{answer}</div>"
|
|
|
|
|
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None): |
|
"""Create HTML for tokens with hover capability and a colored header for the answer.""" |
|
try: |
|
html_parts = [] |
|
ep = dict(eval_points) |
|
marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set() |
|
|
|
|
|
|
|
|
|
for i, token in enumerate(tokens): |
|
|
|
values = ep.get(i, (None, 0, 0)) |
|
confidence, buzz_point, score = values |
|
|
|
|
|
display_token = token |
|
if not re.match(r"\w+", token): |
|
display_token = token.replace(" ", " ") |
|
|
|
|
|
if confidence is None: |
|
css_class = "" |
|
elif not buzz_point: |
|
css_class = " guess-point no-buzz" |
|
else: |
|
css_class = f" guess-point buzz-{score}" |
|
|
|
token_html = f'<span id="token-{i}" class="token{css_class}" data-index="{i}">{display_token}</span>' |
|
if i in marker_indices: |
|
token_html += "<span style='color: rgba(0,0,255,0.3);'>|</span>" |
|
html_parts.append(token_html) |
|
|
|
return f"<div class='token-container'>{''.join(html_parts)}</div>" |
|
except Exception as e: |
|
logging.error(f"Error creating token HTML: {e}", exc_info=True) |
|
return f"<div class='token-container'>Error creating tokens: {str(e)}</div>" |
|
|
|
|
|
def create_line_plot(eval_points, highlighted_index=-1):
    """Build a DataFrame of buzz points (and an optional hover line) for a LinePlot.

    Args:
        eval_points: Iterable of ``(position, (value, buzz, ...))``. Trailing
            fields, such as the score used elsewhere in this module, are
            ignored.
        highlighted_index: Position at which to add a vertical hover line,
            made of two rows at values 0 and 1. A negative value or ``None``
            disables it.

    Returns:
        A DataFrame with columns ``position``, ``value``, ``type``,
        ``highlight`` and ``color``. On error, the error is logged and an
        empty frame with the same columns is returned.
    """
    columns = ["position", "value", "type", "highlight", "color"]
    try:
        data = []
        # ``*_`` accepts both (value, buzz) pairs and (value, buzz, score)
        # triples; unpacking exactly two fields made triples raise.
        for i, (v, b, *_) in eval_points:
            data.append(
                {
                    "position": i,
                    "value": v,
                    "type": "buzz",
                    "highlight": True,
                    "color": "#ff4444" if b == 0 else "#228b22",
                }
            )

        if highlighted_index is not None and highlighted_index >= 0:
            # Two points at y=0 and y=1 draw a vertical line at the hovered position.
            for y in (0, 1):
                data.append(
                    {
                        "position": highlighted_index,
                        "value": y,
                        "type": "hover-line",
                        "highlight": True,
                        "color": "#000000",
                    }
                )

        return pd.DataFrame(data, columns=columns)
    except Exception as e:
        logging.error(f"Error creating line plot: {e}", exc_info=True)
        return pd.DataFrame(columns=columns)
|
|
|
|
|
def create_pyplot(tokens, eval_points, highlighted_index=-1):
    """Plot confidence per token position as a matplotlib figure.

    Args:
        tokens: Token strings, used to label the points where a buzz occurred.
        eval_points: Iterable of ``(token_index, (confidence, buzz, score))``.
            Buzz points are drawn green when ``score`` is truthy and red
            otherwise.
        highlighted_index: Token index at which to draw a dashed vertical
            line. A negative value or ``None`` disables it.

    Returns:
        The matplotlib Figure, created through ``plt.figure``.
    """
    # Materialize so one-shot iterables survive the two passes below.
    points = list(eval_points)

    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # The line is anchored at (0, 0); token i is drawn at x = i + 1.
    x = [0]
    y = [0]
    for i, (v, _b, _s) in points:
        x.append(i + 1)
        y.append(v)
    ax.plot(x, y, "o--", color="#4698cf")

    for i, (v, b, s) in points:
        if not b:
            continue
        ax.plot(i + 1, v, "o", color="green" if s else "red")
        try:
            label = tokens[i]
        except IndexError:
            # Previously this printed a warning and then crashed on tokens[i];
            # now the annotation is skipped instead.
            logging.warning("Token index %s is out of bounds for n_tokens: %d", i, len(tokens))
            continue
        ax.annotate(f"{label}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")

    if highlighted_index is not None and highlighted_index >= 0:
        ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)

    ax.set_title("Buzz Confidence")
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Confidence")
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    return fig
|
|
|
|
|
def create_scatter_pyplot(token_positions, scores):
    """Scatter-plot ``(position, score)`` pairs, sizing markers by frequency.

    Each distinct pair is drawn once, with marker area equal to 20 times the
    number of times that pair occurs. Returns the matplotlib Figure.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    pair_counts = Counter(zip(token_positions, scores))
    xs = [pos for pos, _ in pair_counts]
    ys = [score for _, score in pair_counts]
    sizes = [20 * occurrences for occurrences in pair_counts.values()]

    ax.scatter(xs, ys, color="#4698cf", s=sizes)
    return fig
|
|
|
|
|
def update_plot(highlighted_index, state):
    """Rebuild the confidence plot for a hovered token.

    Args:
        highlighted_index: Hovered token index, as an int or a numeric
            string. ``None`` or ``""`` means nothing is hovered.
        state: A JSON string holding ``tokens`` and ``values`` lists, where
            ``values`` is passed through to ``create_pyplot`` unchanged.

    Returns:
        The figure produced by ``create_pyplot``. An empty DataFrame is
        returned when the state is empty, has no tokens or values, or any
        error occurs (the error is logged).
    """
    try:
        if not state or state == "{}":
            logging.warning("Empty state provided to update_plot")
            return pd.DataFrame()

        # 0 is a valid token index, so only None/"" mean "no highlight".
        # The old truthiness check turned 0 into None, and None then made
        # create_pyplot's ``>= 0`` comparison raise.
        if highlighted_index is None or highlighted_index == "":
            index = -1
        else:
            index = int(highlighted_index)
        logging.info(f"Update plot triggered with token index: {index}")

        data = json.loads(state)
        tokens = data.get("tokens", [])
        values = data.get("values", [])

        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return pd.DataFrame()

        return create_pyplot(tokens, values, index)
    except Exception as e:
        logging.error(f"Error updating plot: {e}", exc_info=True)
        return pd.DataFrame()
|
|