Spaces:
Running
Running
import os, json, datetime, threading, requests, random, re, html | |
from typing import List, Dict, Any | |
import gradio as gr | |
import gspread | |
from google.oauth2.service_account import Credentials | |
from gspread.exceptions import WorksheetNotFound | |
import time | |
# --- Runtime configuration, read once from the environment at import time ---
# URLs that the two retriever helpers POST to.
ENDPOINT_ID_A = os.getenv("ENDPOINT_ID_A")
ENDPOINT_ID_B = os.getenv("ENDPOINT_ID_B")
# Optional: when set, it is sent as a bearer token to retriever system 1.
HF_TOKEN = os.getenv("HF_TOKEN")
# Google service-account key, stored as a JSON string.
SERVICE_ACCOUNT_INFO = os.getenv("GCP_SERVICE_ACCOUNT_JSON")

SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets",
    "https://www.googleapis.com/auth/drive",
]

# Validate before doing any work so a missing variable is reported by its real
# name instead of surfacing as a TypeError from json.loads(None).
if not ENDPOINT_ID_A:
    raise ValueError("ENDPOINT_ID_A is not set")
if not ENDPOINT_ID_B:
    raise ValueError("ENDPOINT_ID_B is not set")
if not SERVICE_ACCOUNT_INFO:
    raise ValueError("GCP_SERVICE_ACCOUNT_JSON is not set")

credentials = Credentials.from_service_account_info(
    json.loads(SERVICE_ACCOUNT_INFO), scopes=SCOPES
)
def _call_retriever_system1(payload: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``payload`` as JSON to retriever system 1 and return the decoded body.

    Args:
        payload: JSON-serialisable request body.

    Returns:
        Whatever ``Response.json()`` yields for the reply.

    Raises:
        RuntimeError: on any transport error, non-2xx status, or a body that
            is not valid JSON. The underlying exception is attached as
            ``__cause__`` so the real failure stays visible in tracebacks.
    """
    headers = {"Accept": "application/json", "Content-Type": "application/json"}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    try:
        r = requests.post(ENDPOINT_ID_A, json=payload, headers=headers, timeout=60)
        r.raise_for_status()
        return r.json()
    # ValueError covers JSON decode failures on requests versions whose
    # JSONDecodeError is not a RequestException subclass.
    except (requests.exceptions.RequestException, ValueError) as e:
        raise RuntimeError("Error: failed to contact retriever system 1. Please try again.") from e
def _call_retriever_system2(query: str, k: int) -> List[Dict[str, Any]]:
    """Query retriever system 2 and return the hit list for ``query``.

    The request body carries the query wrapped in a one-element list and the
    result count as a string. When the reply is a non-empty list, its first
    element is returned; any other reply shape yields ``[]``.

    Args:
        query: Search text.
        k: Number of results to request.

    Raises:
        RuntimeError: on any transport error, non-2xx status, or a body that
            is not valid JSON; the original exception is chained as the cause.
    """
    payload = {"query": [query], "num_results": str(k)}
    try:
        r = requests.post(ENDPOINT_ID_B, json=payload, timeout=60)
        r.raise_for_status()
        data = r.json()
        return data[0] if isinstance(data, list) and data else []
    # ValueError covers JSON decode failures on requests versions whose
    # JSONDecodeError is not a RequestException subclass.
    except (requests.exceptions.RequestException, ValueError) as e:
        raise RuntimeError("Error: failed to contact retriever system 2. Please try again.") from e
# --- Google Sheets handles ---------------------------------------------------
gc = gspread.authorize(credentials)
# Open the spreadsheet once and take both tabs from the same handle instead of
# issuing a second, identical open call.
_spreadsheet = gc.open("arena_votes")
worksheet = _spreadsheet.sheet1                                  # pairwise arena votes
vote_worksheet = _spreadsheet.worksheet("individual_votes")     # per-result up/down votes
# Serialises append_row calls made from concurrent request handlers.
SHEET_LOCK = threading.Lock()
# Radio options shown to the user in Arena mode.
VOTE_UI = ["Retriever A better", "Retriever B better", "Tie", "Both are bad"]
def _save_vote(choice: str, query: str,
               ret_a_json: Dict[str, Any], ret_b_json: List[Dict[str, Any]],
               sys1_is_a: bool) -> Dict[str, Any]:
    """Append one arena vote to the main worksheet and return a status update.

    The UI labels ("Retriever A/B better") are translated back to the real
    system ("System1/System2 better") using ``sys1_is_a``; other choices are
    stored verbatim.

    Args:
        choice: Selected radio label, or a falsy value when nothing is picked.
        query: Query the results were retrieved for.
        ret_a_json: Raw results shown as Retriever A.
        ret_b_json: Raw results shown as Retriever B.
        sys1_is_a: True when System1 was displayed as Retriever A.

    Returns:
        The ``gr.update(...)`` payload for the status component.
    """
    if not choice:
        return gr.update(value="**Please pick a system.**", visible=True)

    if choice == "Retriever A better":
        actual_winner = "System1 better" if sys1_is_a else "System2 better"
    elif choice == "Retriever B better":
        actual_winner = "System2 better" if sys1_is_a else "System1 better"
    else:
        actual_winner = choice

    payload = {
        "retriever_a": ret_a_json,
        "retriever_b": ret_b_json,
        "sys1_is_a": sys1_is_a,
    }
    # utcnow() is deprecated; dropping tzinfo keeps the stored timestamp format
    # (naive ISO-8601, seconds precision) unchanged for existing rows.
    timestamp = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat(timespec="seconds")
    )
    row = [
        timestamp,
        query,
        actual_winner,
        json.dumps(payload, ensure_ascii=False),
        "User Preference",
    ]
    # Report a failed write in the UI, as _save_individual_vote does, instead
    # of letting the exception escape the event handler.
    try:
        with SHEET_LOCK:
            worksheet.append_row(row, value_input_option="RAW")
    except Exception as e:
        return gr.update(value=f"**Error writing to Google Sheets: {e}**", visible=True)
    return gr.update(value="**Vote recorded – thanks!**", visible=True)
def _save_individual_vote(formal_statement: str, vote_decision: str, query: str, system: str, rank: int, payload: Dict[str, Any]) -> str:
    """Append one per-result vote to the ``individual_votes`` worksheet.

    Args:
        formal_statement: Formal statement text of the voted result.
        vote_decision: Vote label sent by the page (e.g. "Upvote").
        query: Query the result was retrieved for.
        system: Identifier of the system that produced the result.
        rank: Rank of the result in its list.
        payload: Extra context, stored as a JSON string in the last column.

    Returns:
        ``"Vote recorded!"`` on success, otherwise a message starting with
        ``"Error writing to Google Sheets:"``. This function never raises.
    """
    # utcnow() is deprecated; dropping tzinfo keeps the stored timestamp format
    # (naive ISO-8601, seconds precision) unchanged for existing rows.
    timestamp = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat(timespec="seconds")
    )
    row = [
        timestamp,
        query,
        formal_statement,
        vote_decision,
        system,
        rank,
        json.dumps(payload, ensure_ascii=False),
    ]
    try:
        with SHEET_LOCK:
            vote_worksheet.append_row(row, value_input_option="RAW")
        return "Vote recorded!"
    except Exception as e:
        return f"Error writing to Google Sheets: {str(e)}"
def _process_informal_statement(text: str) -> str: | |
if not text: | |
return "" | |
text = re.sub(r'`([^`]+)`', r'<code>\1</code>', text) | |
return text | |
def _render_system1_results(res: List[Dict[str, Any]], title: str = "Retriever A", query: str = ""):
    """Render System1 results as an HTML table with copy, doc and vote buttons.

    Each result dict is read for ``formal_statement``, ``informal_statement``
    and an optional ``url``. Missing or ``None`` fields are treated as empty
    strings so one incomplete result cannot abort the whole render.

    Args:
        res: Result dicts in rank order.
        title: Heading shown above the table.
        query: Original query, stored on the vote buttons.

    Returns:
        An HTML string: a "no results" paragraph when ``res`` is empty,
        otherwise a heading followed by the table.
    """
    if not res:
        return f"<p>No results from {title}.</p>"

    query_attr = html.escape(query)

    def make_row(i: int, r: Dict[str, Any]) -> str:
        formal_text = r.get('formal_statement') or ''
        informal_text = r.get('informal_statement') or ''
        formal_escaped = html.escape(formal_text)
        # Attribute values must be escaped too: a double quote in the text
        # would otherwise terminate data-copy-text early.
        informal_attr = html.escape(informal_text)
        doc_url = r.get('url') or ''
        # The URL travels in a data attribute and is read back in the handler,
        # so quotes in it cannot break out of an inline JS string literal.
        doc_button = (
            f'<button class="doc-button" data-url="{html.escape(doc_url)}" '
            f'onclick="window.open(this.getAttribute(\'data-url\'), \'_blank\')" '
            f'title="View documentation">Doc</button>'
        ) if doc_url else ''
        formal_cell = (
            f'<div class="copy-container">'
            f'<code style="white-space:pre-wrap">{formal_escaped}</code>'
            f'<div class="button-container">'
            f'<div class="left-buttons">'
            f'{doc_button}'
            f'<button class="copy-button" data-copy-text="{formal_escaped}" onclick="copyToClipboard(this.getAttribute(\'data-copy-text\'), this)">Copy</button>'
            f'</div>'
            f'<div class="right-buttons">'
            f'<button class="vote-button upvote" data-formal="{formal_escaped}" data-query="{query_attr}" data-system="System1" data-rank="{i}" onclick="voteOnResultSafe(this, \'Upvote\')">👍</button>'
            f'<button class="vote-button downvote" data-formal="{formal_escaped}" data-query="{query_attr}" data-system="System1" data-rank="{i}" onclick="voteOnResultSafe(this, \'Downvote\')">👎</button>'
            f'</div>'
            f'</div>'
            f'</div>'
        )
        informal_cell = (
            f'<div class="copy-container">'
            f'<span style="white-space:pre-wrap">{_process_informal_statement(informal_text)}</span>'
            f'<div class="button-container">'
            f'<button class="copy-button" data-copy-text="{informal_attr}" onclick="copyToClipboard(this.getAttribute(\'data-copy-text\'), this)">Copy</button>'
            f'</div>'
            f'</div>'
        )
        return f"<tr><td>{i}</td><td>{formal_cell}</td><td>{informal_cell}</td></tr>"

    rows = "\n".join(make_row(i, r) for i, r in enumerate(res, 1))
    return (
        f"<h3>{title}</h3>"
        "<table><thead><tr><th>Rank</th>"
        "<th>Formal statement</th><th>Informal statement</th></tr></thead>"
        f"<tbody>{rows}</tbody></table>"
    )
def _render_system2_results(res: List[Dict[str, Any]], title: str = "Retriever B", query: str = ""):
    """Render System2 results as an HTML table with copy, doc and vote buttons.

    Each entry is read for a ``result`` dict with ``kind``, ``name`` (a list of
    components joined with "."), ``signature``, ``value`` and
    ``informal_description``. Missing or ``None`` fields are treated as empty.

    Args:
        res: Result entries in rank order.
        title: Heading shown above the table.
        query: Original query, stored on the vote buttons.

    Returns:
        An HTML string: a "no results" paragraph when ``res`` is empty,
        otherwise a heading followed by the table.
    """
    if not res:
        return f"<p>No results from {title}.</p>"

    query_attr = html.escape(query)

    def row(i: int, e: Dict[str, Any]) -> str:
        r = e.get("result") or {}
        kind = (r.get("kind") or "").strip()
        name = ".".join(r.get("name") or [])
        sig = r.get("signature") or ""
        val = (r.get("value") or "").lstrip()
        formal_text = f"{kind} {name}{sig} {val}".strip()
        informal_text = r.get('informal_description') or ''
        formal_escaped = html.escape(formal_text)
        # Attribute values must be escaped too: a double quote in the text
        # would otherwise terminate data-copy-text early.
        informal_attr = html.escape(informal_text)
        doc_url = f"https://leanprover-community.github.io/mathlib4_docs/find/?pattern={name}#doc" if name else ""
        # Lean names often contain an apostrophe (e.g. foo'). Keeping the URL in
        # a data attribute stops it from ending the inline JS string literal.
        doc_button = (
            f'<button class="doc-button" data-url="{html.escape(doc_url)}" '
            f'onclick="window.open(this.getAttribute(\'data-url\'), \'_blank\')" '
            f'title="View documentation">Doc</button>'
        ) if doc_url else ''
        formal_cell = (
            f'<div class="copy-container">'
            f'<code style="white-space:pre-wrap">{formal_escaped}</code>'
            f'<div class="button-container">'
            f'<div class="left-buttons">'
            f'{doc_button}'
            f'<button class="copy-button" data-copy-text="{formal_escaped}" onclick="copyToClipboard(this.getAttribute(\'data-copy-text\'), this)">Copy</button>'
            f'</div>'
            f'<div class="right-buttons">'
            f'<button class="vote-button upvote" data-formal="{formal_escaped}" data-query="{query_attr}" data-system="System2" data-rank="{i}" onclick="voteOnResultSafe(this, \'Upvote\')">👍</button>'
            f'<button class="vote-button downvote" data-formal="{formal_escaped}" data-query="{query_attr}" data-system="System2" data-rank="{i}" onclick="voteOnResultSafe(this, \'Downvote\')">👎</button>'
            f'</div>'
            f'</div>'
            f'</div>'
        )
        informal_cell = (
            f'<div class="copy-container">'
            f'<span style="white-space:pre-wrap">{_process_informal_statement(informal_text)}</span>'
            f'<div class="button-container">'
            f'<button class="copy-button" data-copy-text="{informal_attr}" onclick="copyToClipboard(this.getAttribute(\'data-copy-text\'), this)">Copy</button>'
            f'</div>'
            f'</div>'
        )
        return f"<tr><td>{i}</td><td>{formal_cell}</td><td>{informal_cell}</td></tr>"

    rows = "\n".join(row(i, e) for i, e in enumerate(res, 1))
    return (
        f"<h3>{title}</h3>"
        "<table><thead><tr><th>Rank</th>"
        "<th>Formal statement</th><th>Informal statement</th></tr></thead>"
        f"<tbody>{rows}</tbody></table>"
    )
# Markdown shown inside the "Supported query types" accordion.
# The non-ASCII symbols were mis-decoded in the source and are restored here.
INSTRUCTIONS_MD = """
## Supported query types
### 1. Informalized statement
Enter an informal translation of a formal Lean statement, and find relevant Lean statements.\n
**Example:** Let L/K be a field extension and let x, y ∈ L be algebraic elements over K with the same minimal polynomial. Then the K-algebra isomorphism algEquiv between the simple field extensions K(x) and K(y) maps the generator x of K(x) to the generator y of K(y); i.e. algEquiv(x) = y.
### 2. User question
Ask any question about Lean statements.\n
**Example:** I'm working with algebraic elements over a field extension … Does this imply that the minimal polynomials of `x` and `y` are equal?
### 3. Statement definition
Enter any fragment or the whole statement definition, and find statements that match the entered content. Note that the query syntax doesn't need to be perfect.\n
**Example:** theorem restrict Ioi: restrict Ioi e = restrict Ici e
### 4. Code snippets
Enter code snippets of statements, and find statements that use the same or similar code.\n
**Example:** rcases hf with ⟨x, rfl⟩ exact ⟨x, fun _ => rfl⟩
"""
# Gradio app
# Stylesheet passed to gr.Blocks(css=...).
# Two defects from the original are fixed:
#   * the "Additional dark mode support" @media block is now closed before the
#     vote-feedback rules, so the light-mode feedback styling applies outside
#     dark mode;
#   * the invalid `:has-text("ERROR")` selector is removed; one invalid selector
#     voids the whole rule, which disabled the `[data-error="true"]` styling.
CUSTOM_CSS = """
html,body{margin:0;padding:0;width:100%;}
.gradio-container,.gradio-container .block{
  max-width:none!important;width:100%!important;padding:0 0.5rem;
}
/* Tables and code blocks */
table{width:100%;border-collapse:collapse;font-size:0.9rem;}
th,td{border:1px solid #ddd;padding:6px;vertical-align:top;}
th{background:#f5f5f5;font-weight:600;}
code{background:#f8f8f8;border:1px solid #eee;border-radius:4px;padding:2px 4px;color:#333;}
td code{
  background:#f0f0f0;
  border:1px solid #ddd;
  border-radius:3px;
  padding:1px 3px;
  font-size:0.9em;
  font-family:Monaco, Consolas, "Courier New", monospace;
}
/* Dark mode support */
@media (prefers-color-scheme: dark) {
  table{color:#e0e0e0;}
  th,td{border:1px solid #555;}
  th{background:#2a2a2a;color:#e0e0e0;}
  code{background:#2a2a2a;border:1px solid #555;color:#e0e0e0;}
  td code{
    background:#333;
    border:1px solid #555;
    color:#e0e0e0;
  }
}
/* Arena voting controls */
#vote_area{
  margin-top:1.5rem;
  align-items:center;
  justify-content:center;
  gap:1rem;
  flex-wrap:wrap;
  border:none!important;
  box-shadow:none!important;
}
#vote_radio_col{
  display:flex;
  flex-direction:column;
  align-items:center;
  gap:0.75rem;
}
#vote_radio .gr-radio{display:flex;gap:0.75rem;}
#vote_radio label{
  padding:4px 14px;
  border:1px solid #ccc;
  border-radius:8px;
  cursor:pointer;
  user-select:none;
  transition:all .15s ease;
}
#vote_radio input[type="radio"]:checked + label{
  background:#0066ff;color:#fff;border-color:#0066ff;
}
#submit_btn button{
  padding:0.65rem 1.6rem;
  font-weight:600;font-size:1rem;border-radius:8px;
}
#vote_status{margin-top:0.5rem;text-align:center;} /* works for Markdown */
#lf_header.gr-column{
  display:flex !important;
  flex-direction:row !important;
  align-items:center !important;
  justify-content:center !important;
  gap:1rem !important;
  width:100% !important;
  margin:1rem auto !important;
  padding:0 !important;
}
/* Alternative selector in case the above doesn't work */
div#lf_header{
  display:flex !important;
  flex-direction:row !important;
  align-items:center !important;
  justify-content:center !important;
  gap:1rem !important;
  width:100% !important;
  margin:1rem auto !important;
  padding:0 !important;
}
#lf_logo{
  width:60px !important;
  height:60px !important;
  flex:0 0 60px !important;
  overflow:hidden;
}
#lf_logo .gr-image-toolbar{
  display:none !important;
}
/* Fix the HTML container for the title */
#lf_header .gr-html{
  width:auto !important;
  min-width:auto !important;
  max-width:none !important;
  flex:0 0 auto !important;
}
.lf-title{
  margin:0 !important;
  padding:0 !important;
  font-size:1.6rem !important;
  font-weight:700 !important;
  white-space:normal !important;
  color:#333 !important;
  line-height:1.2 !important;
  text-align:center !important;
}
.gr-accordion-header {
  font-size: 1.05rem;
  font-weight: 600;
  cursor: pointer;
  padding: 0.4rem 0;
}
/* Progress bar overlay fix for HF Spaces */
div.gradio-modal[aria-label="progress"] /* outer overlay on HF */
{
  position: fixed !important; /* pull it out of the normal flow */
  top: 50% !important; /* perfectly centred in the viewport */
  left: 50% !important;
  transform: translate(-50%, -50%) !important;
  z-index: 2000 !important;
  width: clamp(260px, 70vw, 440px) !important;
  max-height: 140px !important;
  padding: 20px 24px !important;
  /* optional aesthetic tweaks - remove if you like HF's defaults */
  background: var(--block-background-fill, #fff) !important;
  border: 1px solid #ddd !important;
  border-radius: 8px !important;
  box-shadow: 0 4px 12px rgba(0,0,0,.15) !important;
  pointer-events: none !important;
}
/* Copy button styling */
.copy-container {
  position: relative;
  display: block;
  width: 100%;
  min-height: 1.5em;
}
.copy-button {
  background: rgba(0, 0, 0, 0.1);
  border: 1px solid #ccc;
  border-radius: 4px;
  padding: 4px 8px;
  font-size: 12px;
  cursor: pointer;
  transition: background-color 0.2s ease;
  color: #666;
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  line-height: 1.2;
  white-space: nowrap;
  vertical-align: baseline;
  box-sizing: border-box;
}
.copy-button:hover {
  background: rgba(0, 0, 0, 0.2);
}
.copy-button:active {
  background: rgba(0, 0, 0, 0.3);
}
td {
  position: relative;
}
/* Query input styling */
#query_input textarea {
  resize: vertical !important;
  min-height: 100px !important;
  max-height: 400px !important;
  font-size: 14px !important;
  line-height: 1.5 !important;
  padding: 12px !important;
}
#query_input .gr-textbox {
  min-height: 100px !important;
}
/* Alternative selectors for different Gradio versions */
textarea[data-testid="textbox"] {
  resize: vertical !important;
  min-height: 100px !important;
  max-height: 400px !important;
}
div[data-testid="textbox"] textarea {
  resize: vertical !important;
  min-height: 100px !important;
  max-height: 400px !important;
}
/* Vote buttons styling */
.vote-button {
  background: transparent;
  border: 1px solid #ddd;
  border-radius: 4px;
  padding: 4px 8px;
  cursor: pointer;
  font-size: 16px;
  line-height: 1.2;
  transition: all 0.2s ease;
  opacity: 0.6;
  vertical-align: baseline;
  box-sizing: border-box;
}
.vote-button:hover {
  opacity: 1;
  transform: scale(1.1);
}
.vote-button.upvote {
  color: #28a745;
  border-color: #28a745;
}
.vote-button.downvote {
  color: #dc3545;
  border-color: #dc3545;
}
/* Doc button styling */
.doc-button {
  background: rgba(0, 123, 255, 0.1);
  border: 1px solid #007bff;
  border-radius: 4px;
  padding: 4px 8px;
  font-size: 12px;
  cursor: pointer;
  transition: background-color 0.2s ease;
  color: #007bff;
  font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
  line-height: 1.2;
  white-space: nowrap;
  vertical-align: baseline;
  box-sizing: border-box;
}
.doc-button:hover {
  background: #007bff;
  color: white;
}
/* Button container for all buttons */
.button-container {
  display: flex;
  margin-top: 8px;
  align-items: baseline;
  justify-content: space-between;
}
.left-buttons {
  display: flex;
  gap: 8px;
  align-items: baseline;
}
.right-buttons {
  display: flex;
  gap: 8px;
  align-items: baseline;
}
.vote-button.upvote:hover {
  background: #28a745;
  color: white;
}
.vote-button.downvote:hover {
  background: #dc3545;
  color: white;
}
.vote-button.voted {
  opacity: 1;
  transform: scale(1.1);
}
.vote-button.upvote.voted {
  background: #28a745;
  color: white;
}
.vote-button.downvote.voted {
  background: #dc3545;
  color: white;
}
/* Additional dark mode support */
@media (prefers-color-scheme: dark) {
  .lf-title{
    color:#e0e0e0 !important;
  }
  #vote_radio label{
    border:1px solid #555;
    background:#2a2a2a;
    color:#e0e0e0;
  }
  #vote_radio input[type="radio"]:checked + label{
    background:#0066ff;
    color:#fff;
    border-color:#0066ff;
  }
  /* Progress bar dark mode */
  .gradio-container .progress-bar {
    background: #2a2a2a !important;
    border: 1px solid #555 !important;
  }
  .gradio-container .progress-bar .progress-text {
    color: #e0e0e0 !important;
  }
  .gradio-container .progress-bar .progress-level {
    background: #444 !important;
  }
  /* Copy button dark mode */
  .copy-button {
    background: rgba(255, 255, 255, 0.1);
    color: #ccc;
    border-color: #555;
  }
  .copy-button:hover {
    background: rgba(255, 255, 255, 0.2);
    color: #fff;
  }
  .copy-button:active {
    background: rgba(255, 255, 255, 0.3);
    color: #fff;
  }
  /* Query input dark mode */
  #query_input textarea {
    background: #2a2a2a !important;
    color: #e0e0e0 !important;
  }
  #query_input textarea:focus {
    border-color: #0066ff !important;
    box-shadow: 0 0 0 2px rgba(0, 102, 255, 0.2) !important;
  }
  /* Vote buttons dark mode */
  .vote-button {
    border-color: #555 !important;
    background: #2a2a2a !important;
  }
  .vote-button.upvote {
    color: #4ade80 !important;
    border-color: #4ade80 !important;
  }
  .vote-button.downvote {
    color: #f87171 !important;
    border-color: #f87171 !important;
  }
  /* Doc button dark mode */
  .doc-button {
    background: rgba(96, 165, 250, 0.1) !important;
    border-color: #60a5fa !important;
    color: #60a5fa !important;
  }
  .doc-button:hover {
    background: #60a5fa !important;
    color: #000 !important;
  }
  .vote-button.upvote:hover {
    background: #4ade80 !important;
    color: #000 !important;
  }
  .vote-button.downvote:hover {
    background: #f87171 !important;
    color: #000 !important;
  }
  .vote-button.upvote.voted {
    background: #4ade80 !important;
    color: #000 !important;
  }
  .vote-button.downvote.voted {
    background: #f87171 !important;
    color: #000 !important;
  }
}
/* Vote feedback styling */
#vote_feedback {
  margin-top: 1rem;
  text-align: center;
}
#vote_feedback .markdown {
  background: #d4edda !important;
  border: 1px solid #c3e6cb !important;
  border-radius: 8px !important;
  padding: 12px 16px !important;
  color: #155724 !important;
  font-weight: 500 !important;
  margin: 0 !important;
  transition: opacity 0.5s ease !important;
  opacity: 1 !important;
}
/* Error styling for vote feedback (data-error is set by voteOnResultSafe) */
#vote_feedback .markdown[data-error="true"] {
  background: #f8d7da !important;
  border-color: #f5c6cb !important;
  color: #721c24 !important;
}
/* Dark mode for vote feedback */
@media (prefers-color-scheme: dark) {
  #vote_feedback .markdown {
    background: #1e3a2e !important;
    border-color: #2d5a3d !important;
    color: #a7d4b4 !important;
  }
  /* Error styling in dark mode */
  #vote_feedback .markdown[data-error="true"] {
    background: #3d1a1a !important;
    border-color: #5a2d2d !important;
    color: #ff9999 !important;
  }
}
"""
# Build the Gradio UI. The `head` string injects KaTeX plus the page-side
# helpers (math rendering, clipboard copy, per-result voting) that the rendered
# result tables call through inline `onclick` handlers.
with gr.Blocks(
    title="Lean Finder Retrieval",
    css=CUSTOM_CSS,
    head="""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/katex.min.css">
<script src="https://cdn.jsdelivr.net/npm/katex@0.16.8/dist/katex.min.js"></script>
<script>
function renderKatex() {
    if (typeof katex === 'undefined') {
        setTimeout(renderKatex, 500);
        return;
    }
    const spans = document.querySelectorAll('span:not([data-katex-processed])');
    spans.forEach(function(span) {
        const text = span.textContent;
        if (text && text.includes('$')) {
            span.setAttribute('data-katex-processed', 'true');
            let mathExpressions = [];
            // Find inline math $...$
            let pos = 0;
            while (pos < text.length) {
                const start = text.indexOf('$', pos);
                if (start === -1) break;
                const end = text.indexOf('$', start + 1);
                if (end === -1) break;
                const isDisplayMath = (start > 0 && text[start - 1] === '$') ||
                    (end < text.length - 1 && text[end + 1] === '$');
                if (!isDisplayMath) {
                    mathExpressions.push({
                        start: start,
                        end: end + 1,
                        content: text.substring(start + 1, end),
                        type: 'inline'
                    });
                    pos = end + 1;
                } else {
                    pos = start + 1;
                }
            }
            // Find display math $$...$$
            pos = 0;
            while (pos < text.length) {
                const start = text.indexOf('$$', pos);
                if (start === -1) break;
                const end = text.indexOf('$$', start + 2);
                if (end === -1) break;
                mathExpressions.push({
                    start: start,
                    end: end + 2,
                    content: text.substring(start + 2, end),
                    type: 'display'
                });
                pos = end + 2;
            }
            mathExpressions.sort((a, b) => a.start - b.start);
            // Remove overlapping expressions
            let filtered = [];
            for (let expr of mathExpressions) {
                let shouldAdd = true;
                for (let j = 0; j < filtered.length; j++) {
                    const existing = filtered[j];
                    if (!(expr.end <= existing.start || expr.start >= existing.end)) {
                        if (expr.type === 'display' && existing.type === 'inline') {
                            filtered.splice(j, 1);
                            j--;
                        } else {
                            shouldAdd = false;
                            break;
                        }
                    }
                }
                if (shouldAdd) filtered.push(expr);
            }
            if (filtered.length > 0) {
                let newContent = document.createDocumentFragment();
                let lastPos = 0;
                filtered.forEach(function(expr) {
                    if (expr.start > lastPos) {
                        newContent.appendChild(document.createTextNode(text.substring(lastPos, expr.start)));
                    }
                    try {
                        const mathElement = document.createElement('span');
                        mathElement.innerHTML = katex.renderToString(expr.content, {
                            throwOnError: false,
                            displayMode: expr.type === 'display'
                        });
                        newContent.appendChild(mathElement);
                    } catch (e) {
                        const fallback = expr.type === 'display' ? '$$' + expr.content + '$$' : '$' + expr.content + '$';
                        newContent.appendChild(document.createTextNode(fallback));
                    }
                    lastPos = expr.end;
                });
                if (lastPos < text.length) {
                    newContent.appendChild(document.createTextNode(text.substring(lastPos)));
                }
                span.innerHTML = '';
                span.appendChild(newContent);
            }
        }
    });
}
setTimeout(renderKatex, 1000);
setTimeout(renderKatex, 3000);
setTimeout(renderKatex, 5000);
function copyToClipboard(text, button) {
    if (navigator.clipboard && window.isSecureContext) {
        navigator.clipboard.writeText(text).then(function() {
            showCopySuccess(button);
        }).catch(function() {
            fallbackCopy(text, button);
        });
    } else {
        fallbackCopy(text, button);
    }
}
function fallbackCopy(text, button) {
    const textArea = document.createElement('textarea');
    textArea.value = text;
    textArea.style.position = 'fixed';
    textArea.style.left = '-9999px';
    textArea.style.top = '-9999px';
    document.body.appendChild(textArea);
    textArea.focus();
    textArea.select();
    try {
        document.execCommand('copy');
        showCopySuccess(button);
    } catch (err) {
        showCopyError(button);
    }
    document.body.removeChild(textArea);
}
function showCopySuccess(button) {
    const originalText = button.innerHTML;
    button.innerHTML = '✓';
    button.style.color = '#28a745';
    setTimeout(function() {
        button.innerHTML = originalText;
        button.style.color = '';
    }, 1000);
}
function showCopyError(button) {
    const originalText = button.innerHTML;
    button.innerHTML = '✗';
    button.style.color = '#dc3545';
    setTimeout(function() {
        button.innerHTML = originalText;
        button.style.color = '';
    }, 1000);
}
function voteOnResultSafe(button, voteDecision) {
    const formalStatement = button.getAttribute('data-formal');
    const query = button.getAttribute('data-query');
    const system = button.getAttribute('data-system');
    const rank = parseInt(button.getAttribute('data-rank'), 10);
    button.classList.add('voted');
    const otherButton = button.parentElement.querySelector('.vote-button:not(.' + (voteDecision === 'Upvote' ? 'upvote' : 'downvote') + ')');
    if (otherButton) {
        otherButton.style.opacity = '0.3';
        otherButton.style.pointerEvents = 'none';
    }
    const originalText = button.innerHTML;
    button.innerHTML = voteDecision === 'Upvote' ? '👍✓' : '👎✓';
    setTimeout(function() {
        button.innerHTML = originalText;
    }, 2000);
    const voteData = JSON.stringify({
        formal_statement: formalStatement,
        vote_decision: voteDecision,
        query: query,
        system: system,
        rank: rank
    });
    const voteContainer = document.querySelector('#vote_trigger_input');
    if (voteContainer) {
        const input = voteContainer.querySelector('textarea') || voteContainer.querySelector('input[type="text"]');
        if (input) {
            input.value = voteData;
            input.dispatchEvent(new Event('input', { bubbles: true, cancelable: true }));
            input.dispatchEvent(new Event('change', { bubbles: true, cancelable: true }));
            setTimeout(function() {
                const feedbackElement = document.querySelector('#vote_feedback .markdown');
                if (feedbackElement) {
                    if (feedbackElement.textContent.includes('ERROR')) {
                        feedbackElement.setAttribute('data-error', 'true');
                    } else {
                        feedbackElement.removeAttribute('data-error');
                    }
                }
            }, 100);
            setTimeout(function() {
                const feedbackElement = document.querySelector('#vote_feedback .markdown');
                if (feedbackElement && feedbackElement.textContent.trim()) {
                    feedbackElement.style.opacity = '0';
                    setTimeout(function() {
                        const input = voteContainer.querySelector('textarea') || voteContainer.querySelector('input[type="text"]');
                        if (input) {
                            input.value = '';
                            input.dispatchEvent(new Event('input', { bubbles: true }));
                        }
                    }, 500);
                }
            }, 4000);
        }
    }
}
function voteOnResult(formalStatement, voteDecision, query, system, rank, button) {
    // Node.parentElement is read-only, so a detached stand-in button can never
    // be given a parent; voteOnResultSafe would then fail on a null
    // parentElement. Store the vote data on the real button and delegate.
    button.setAttribute('data-formal', formalStatement);
    button.setAttribute('data-query', query);
    button.setAttribute('data-system', system);
    button.setAttribute('data-rank', rank);
    voteOnResultSafe(button, voteDecision);
}
setInterval(function() {
    if (document.querySelector('table')) {
        renderKatex();
    }
}, 3000);
</script>
"""
) as demo:
    # NOTE(review): the source's indentation was lost, so the nesting of the
    # layout containers below is reconstructed from statement order and the
    # CSS selectors - confirm against the deployed layout.
    with gr.Column(elem_id="lf_header"):
        gr.Image(
            value="lean_finder_logo.png",
            show_label=False,
            interactive=False,
            container=False,
            show_download_button=False,
            height=60,
            elem_id="lf_logo",
        )
        gr.HTML('<h1 class="lf-title">Lean Finder: Semantic Search for Mathlib That Understands User Intents</h1>')
    with gr.Accordion("Supported query types(click for details): Informalized statement, User question, Statement definition, Code snippet", open=False):
        gr.Markdown(INSTRUCTIONS_MD)
    with gr.Row():
        query_box = gr.Textbox(label="Query", lines=4, max_lines=20,
                               placeholder="Type your query here …",
                               elem_id="query_input")
        topk_slider = gr.Slider(label="Number of results", minimum=1, maximum=50, step=1, value=5)
        with gr.Column():
            mode_sel = gr.Radio(["Arena", "Normal"], value="Arena", label="Mode")
            mode_description = gr.Markdown("Arena mode: Compare retrieval results from Lean Finder with another retriever and vote.")
    run_btn = gr.Button("Retrieve")
    with gr.Row(elem_id="vote_area"):
        with gr.Column(elem_id="vote_radio_col"):
            vote_radio = gr.Radio(
                VOTE_UI,
                label="Which result is better?",
                visible=False,
                elem_id="vote_radio"
            )
            submit_btn = gr.Button(
                "Submit vote",
                visible=False,
                elem_id="submit_btn",
                variant="primary"
            )
            vote_status = gr.Markdown("", visible=False, elem_id="vote_status")
    results_html = gr.HTML()
    # Hidden component for individual vote triggers: the page-side
    # voteOnResultSafe() writes a JSON string into it to reach the backend.
    vote_trigger = gr.Textbox(visible=False, elem_id="vote_trigger_input")
    vote_feedback = gr.Markdown("", visible=True, elem_id="vote_feedback")
    # per-session state
    st_query = gr.State("")
    st_ret_a_js = gr.State({})
    st_ret_b_js = gr.State([])
    st_sys1_is_a = gr.State(True)  # Track which system is A in current session

    def retrieve(query: str, k: int, mode: str, progress=gr.Progress()):
        """Run retrieval for ``query`` and build every output of the Retrieve click.

        System1 is always queried. If the first call fails it is retried about
        once per second for up to 300 seconds, with progress updates, before
        giving up. In Arena mode System2 is queried as well and the two result
        sets are assigned at random to "Retriever A" and "Retriever B"; if
        System2 fails, the System1 results are shown alone with a notice.

        Returns:
            An 8-tuple matching ``run_btn.click(outputs=...)``: results HTML,
            query, retriever A results, retriever B results, vote radio update,
            submit button update, vote status update, and whether System1 is A.
        """
        query = query.strip()
        hide = gr.update(visible=False, value="")
        if not query:
            return "<p>Please enter a query.</p>", query, {}, [], hide, hide, hide, True

        payload = {"inputs": query, "top_k": k}
        try:
            sys1_json = _call_retriever_system1(payload).get("results", [])
        except RuntimeError:
            startup_msg = "Please wait about 2 minutes. The Lean Finder service is starting up. Results will be displayed automatically once Lean Finder is ready."
            progress(0, desc=startup_msg)
            start_time = time.time()
            timeout_duration = 300
            progress_phase1_duration = 120
            last_progress_update = 0
            while time.time() - start_time < timeout_duration:
                elapsed_time = time.time() - start_time
                # Refresh the progress bar at most every 10 seconds.
                if elapsed_time - last_progress_update >= 10:
                    if elapsed_time < progress_phase1_duration:
                        progress(elapsed_time / progress_phase1_duration, desc=startup_msg)
                    else:
                        progress(1.0, desc="Experiencing a slow start. Please wait a little longer – the content will be displayed once Lean Finder is ready.")
                    last_progress_update = elapsed_time
                time.sleep(1)
                try:
                    sys1_json = _call_retriever_system1(payload).get("results", [])
                    break
                except RuntimeError:
                    continue
            else:
                # The loop ran out of time without a successful call.
                error_message = "<div style='background-color: #f8d7da; border: 1px solid #f5c6cb; padding: 10px; margin-bottom: 15px; border-radius: 5px; color: #721c24;'><strong>Error:</strong> Lean Finder service is currently unavailable. Please contact the maintainer of this project at mike_lu@sfu.ca</div>"
                return error_message, query, {}, [], hide, hide, hide, True

        if mode == "Normal":
            sys1_html = _render_system1_results(sys1_json, title="Lean Finder", query=query)
            return sys1_html, query, sys1_json, [], hide, hide, hide, True

        try:
            sys2_json = _call_retriever_system2(query, k)
        except RuntimeError:
            error_message = "<div style='background-color: #fff3cd; border: 1px solid #ffeaa7; padding: 10px; margin-bottom: 15px; border-radius: 5px; color: #856404;'><strong>Notice:</strong> The other retriever Lean Search is currently unavailable, falling back to Normal mode.</div>"
            sys1_html = _render_system1_results(sys1_json, title="Lean Finder", query=query)
            fallback_html = error_message + sys1_html
            return fallback_html, query, sys1_json, [], hide, hide, hide, True

        # Randomise which system is shown as "A" so voters cannot tell them apart.
        sys1_is_a = random.choice([True, False])
        if sys1_is_a:
            ret_a_json, ret_b_json = sys1_json, sys2_json
            ret_a_html = _render_system1_results(ret_a_json, title="Retriever A", query=query)
            ret_b_html = _render_system2_results(ret_b_json, title="Retriever B", query=query)
        else:
            ret_a_json, ret_b_json = sys2_json, sys1_json
            ret_a_html = _render_system2_results(ret_a_json, title="Retriever A", query=query)
            ret_b_html = _render_system1_results(ret_b_json, title="Retriever B", query=query)
        page = (
            "<div style='display:flex; gap:0.5rem;'>"
            f"<div style='flex:1 1 0;'>{ret_a_html}</div>"
            f"<div style='flex:1 1 0;'>{ret_b_html}</div>"
            "</div>"
        )
        show_radio = gr.update(visible=True, value=None)
        hide_status = gr.update(visible=False, value="")
        show_btn = gr.update(visible=True)
        return page, query, ret_a_json, ret_b_json, show_radio, show_btn, hide_status, sys1_is_a

    run_btn.click(
        retrieve,
        inputs=[query_box, topk_slider, mode_sel],
        outputs=[results_html, st_query, st_ret_a_js, st_ret_b_js,
                 vote_radio, submit_btn, vote_status, st_sys1_is_a],
    )

    def _reset_ui_on_mode_change(mode):
        """Clear results and hide voting controls when the mode radio changes.

        Returns the updates for results HTML, vote radio, submit button, vote
        status, and the mode description text, in that order.
        """
        if mode == "Arena":
            description = "Arena mode: Compare retrieval results from Lean Finder with another retriever and vote."
        else:
            description = "Normal mode: Standard retrieval from Lean Finder."
        return (
            "",
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=""),
            description,
        )

    mode_sel.change(
        _reset_ui_on_mode_change,
        inputs=mode_sel,
        outputs=[results_html, vote_radio, submit_btn, vote_status, mode_description],
    )
    submit_btn.click(
        _save_vote,
        inputs=[vote_radio, st_query, st_ret_a_js, st_ret_b_js, st_sys1_is_a],
        outputs=vote_status
    )

    def handle_individual_vote(vote_data: str, query: str, ret_a_js: Dict[str, Any], ret_b_js: List[Dict[str, Any]], sys1_is_a: bool) -> str:
        """Persist one up/down vote sent by the page through ``vote_trigger``.

        Args:
            vote_data: JSON string written by ``voteOnResultSafe``; an empty
                value (the page clearing the field) returns ``""``.
            query: Query stored in session state (the vote carries its own copy).
            ret_a_js: Session's retriever A results, saved as context.
            ret_b_js: Session's retriever B results, saved as context.
            sys1_is_a: Whether System1 was displayed as Retriever A.

        Returns:
            Markdown for the feedback area. Every failure message contains
            "ERROR", the marker the page script uses to apply error styling.
        """
        if not vote_data:
            return ""
        try:
            data = json.loads(vote_data)
            payload = {
                "retriever_a": ret_a_js,
                "retriever_b": ret_b_js,
                "sys1_is_a": sys1_is_a,
                "individual_vote": data
            }
            status = _save_individual_vote(
                data["formal_statement"],
                data["vote_decision"],
                data["query"],
                data["system"],
                data["rank"],
                payload
            )
        except (json.JSONDecodeError, KeyError) as e:
            return f"ERROR saving vote: {str(e)}"
        except Exception as e:
            return f"ERROR: {str(e)}"
        # _save_individual_vote reports a failed sheet write through its return
        # value rather than raising, so check it before claiming success.
        if isinstance(status, str) and status.startswith("Error"):
            return f"ERROR: {status}"
        return f"**{data['vote_decision']} recorded!** (Rank {data['rank']})"

    vote_trigger.change(
        handle_individual_vote,
        inputs=[vote_trigger, st_query, st_ret_a_js, st_ret_b_js, st_sys1_is_a],
        outputs=vote_feedback
    )
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()