Spaces: Running on Zero
File size: 23,208 Bytes
from langgraph.checkpoint.memory import MemorySaver
from huggingface_hub import snapshot_download
from dotenv import load_dotenv
from datetime import datetime
import gradio as gr
import spaces
import torch
import uuid
import ast
import os
import re
# Local modules
from main import GetChatModel, openai_model, model_id
from util import get_sources, get_start_end_months
from retriever import db_dir, embedding_model_id
from mods.tool_calling_llm import extract_think
from data import download_data, extract_data
from graph import BuildGraph
# Set environment variables
load_dotenv(dotenv_path=".env", override=True)
# Hide BM25S progress bars
os.environ["DISABLE_TQDM"] = "true"
# Download model snapshots from Hugging Face Hub
if torch.cuda.is_available():
print(f"Downloading checkpoints for {model_id}...")
ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")
print(f"Downloading checkpoints for {embedding_model_id}...")
embedding_ckpt_dir = snapshot_download(
embedding_model_id, local_dir_use_symlinks=False
)
print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
else:
ckpt_dir = None
embedding_ckpt_dir = None
# Download and extract data if data directory is not present
if not os.path.isdir(db_dir):
print("Downloading data ... ", end="")
download_data()
print("done!")
print("Extracting data ... ", end="")
extract_data()
print("done!")
# Global setting for search type
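# "hybrid" combines the dense (ChromaDB) and sparse (BM25S) retrievers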
search_type = "hybrid"
# Global variables for LangChain graph: use dictionaries to store user-specific instances
# https://www.gradio.app/guides/state-in-blocks
graph_instances = {"local": {}, "remote": {}}
def cleanup_graph(request: gr.Request):
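    """Delete a session's cached graph instances when the page is closed or refreshed"""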
timestamp = datetime.now().replace(microsecond=0).isoformat()
if request.session_hash in graph_instances["local"]:
del graph_instances["local"][request.session_hash]
print(f"{timestamp} - Delete local graph for session {request.session_hash}")
if request.session_hash in graph_instances["remote"]:
del graph_instances["remote"][request.session_hash]
print(f"{timestamp} - Delete remote graph for session {request.session_hash}")
def append_content(chunk_messages, history, thinking_about):
"""Append thinking and non-thinking content to chatbot history"""
if chunk_messages.content:
think_text, post_think = extract_think(chunk_messages.content)
# Show thinking content in "metadata" message
if think_text:
history.append(
gr.ChatMessage(
role="assistant",
content=think_text,
metadata={"title": f"π§ Thinking about the {thinking_about}"},
)
)
if not post_think and not chunk_messages.tool_calls:
gr.Warning("Response may be incomplete", title="Thinking-only response")
# Display non-thinking content
if post_think:
history.append(gr.ChatMessage(role="assistant", content=post_think))
return history
def run_workflow(input, history, compute_mode, thread_id, session_hash):
"""The main function to run the chat workflow"""
# Error if user tries to run local mode without GPU
if compute_mode == "local":
if not torch.cuda.is_available():
raise gr.Error(
"Local mode requires GPU.",
print_exception=False,
)
# Get graph instance
graph = graph_instances[compute_mode].get(session_hash)
if graph is None:
# Notify when we're loading the local model because it takes some time
if compute_mode == "local":
gr.Info(
f"Please wait for the local model to load",
title=f"Model loading...",
)
# Get the chat model and build the graph
chat_model = GetChatModel(compute_mode, ckpt_dir)
graph_builder = BuildGraph(
chat_model,
compute_mode,
search_type,
embedding_ckpt_dir=embedding_ckpt_dir,
)
# Compile the graph with an in-memory checkpointer
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)
# Set global graph for compute mode
graph_instances[compute_mode][session_hash] = graph
        # ISO 8601 timestamp in local time without microseconds
timestamp = datetime.now().replace(microsecond=0).isoformat()
print(f"{timestamp} - Set {compute_mode} graph for session {session_hash}")
# Notify when model finishes loading
gr.Success(f"{compute_mode}", duration=4, title=f"Model loaded!")
else:
timestamp = datetime.now().replace(microsecond=0).isoformat()
print(f"{timestamp} - Get {compute_mode} graph for session {session_hash}")
# print(f"Using thread_id: {thread_id}")
# Display the user input in the chatbot
history.append(gr.ChatMessage(role="user", content=input))
    # Return the message history and empty lists for the emails and citations textboxes
yield history, [], []
# Stream graph steps for a single input
# https://langchain-ai.lang.chat/langgraph/reference/graphs/#langgraph.graph.state.CompiledStateGraph
for step in graph.stream(
# Appends the user input to the graph state
{"messages": [{"role": "user", "content": input}]},
config={"configurable": {"thread_id": thread_id}},
):
# Get the node name and output chunk
node, chunk = next(iter(step.items()))
if node == "query":
# Get the message (AIMessage class in LangChain)
chunk_messages = chunk["messages"]
# Append thinking and non-thinking messages (if present)
history = append_content(chunk_messages, history, thinking_about="query")
# Look for tool calls
if chunk_messages.tool_calls:
# Loop over tool calls
for tool_call in chunk_messages.tool_calls:
# Show the tool call with arguments used
args = tool_call["args"]
content = args["search_query"] if "search_query" in args else ""
start_year = args["start_year"] if "start_year" in args else None
end_year = args["end_year"] if "end_year" in args else None
if start_year or end_year:
content = f"{content} ({start_year or ''} - {end_year or ''})"
if "months" in args:
content = f"{content} {args['months']}"
history.append(
gr.ChatMessage(
role="assistant",
content=content,
metadata={"title": f"π Running tool {tool_call['name']}"},
)
)
yield history, [], []
if node == "retrieve_emails":
chunk_messages = chunk["messages"]
# Loop over tool calls
count = 0
retrieved_emails = []
for message in chunk_messages:
count += 1
# Get the retrieved emails as a list
email_list = message.content.replace(
"### Retrieved Emails:\n\n", ""
).split("--- --- --- --- Next Email --- --- --- ---\n\n")
# Get the list of source files (e.g. R-help/2024-December.txt) for retrieved emails
month_list = [text.splitlines()[0] for text in email_list]
# Format months (e.g. 2024-December) into text
month_text = (
", ".join(month_list).replace("R-help/", "").replace(".txt", "")
)
# Get the number of retrieved emails
n_emails = len(email_list)
title = f"π Retrieved {n_emails} emails"
if email_list[0] == "### No emails were retrieved":
title = "β Retrieved 0 emails"
history.append(
gr.ChatMessage(
role="assistant",
content=month_text,
metadata={"title": title},
)
)
# Format the retrieved emails with Tool Call heading
retrieved_emails.append(
message.content.replace(
"### Retrieved Emails:\n\n",
f"### ### ### ### Tool Call {count} ### ### ### ###\n\n",
)
)
# Combine all the Tool Call results
retrieved_emails = "\n\n".join(retrieved_emails)
yield history, retrieved_emails, []
if node == "answer":
# Append messages (thinking and non-thinking) to history
chunk_messages = chunk["messages"]
history = append_content(chunk_messages, history, thinking_about="answer")
# None is used for no change to the retrieved emails textbox
yield history, None, []
if node == "answer_with_citations":
# Parse the message for the answer and citations
chunk_messages = chunk["messages"][0]
try:
answer, citations = ast.literal_eval(chunk_messages.content)
            except (ValueError, SyntaxError):
# In case we got an answer without citations
answer = chunk_messages.content
citations = None
history.append(gr.ChatMessage(role="assistant", content=answer))
yield history, None, citations
def to_workflow(request: gr.Request, *args):
"""Wrapper function to call function with or without @spaces.GPU"""
input = args[0]
compute_mode = args[2]
# Add session_hash to arguments
new_args = args + (request.session_hash,)
if compute_mode == "local":
# Call the workflow function with the @spaces.GPU decorator
for value in run_workflow_local(*new_args):
yield value
if compute_mode == "remote":
for value in run_workflow_remote(*new_args):
yield value
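# @spaces.GPU requests ZeroGPU hardware for each call; duration is the allocation limit in seconds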
@spaces.GPU(duration=100)
def run_workflow_local(*args):
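    """Run the chat workflow with ZeroGPU allocated (local mode)"""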
for value in run_workflow(*args):
yield value
def run_workflow_remote(*args):
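    """Run the chat workflow without requesting GPU hardware (remote mode)"""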
for value in run_workflow(*args):
yield value
# Custom CSS for bottom alignment
css = """
.row-container {
display: flex;
align-items: flex-end; /* Align components at the bottom */
gap: 10px; /* Add spacing between components */
}
"""
with gr.Blocks(
title="R-help-chat",
    # Noto Color Emoji gets a nice-looking Unicode Character “🇷” (U+1F1F7) on Chrome
theme=gr.themes.Soft(
font=[
"ui-sans-serif",
"system-ui",
"sans-serif",
"Apple Color Emoji",
"Segoe UI Emoji",
"Segoe UI Symbol",
"Noto Color Emoji",
]
),
css=css,
) as demo:
# -----------------
# Define components
# -----------------
compute_mode = gr.Radio(
choices=[
"local",
"remote",
],
# Default to remote because it provides a better first impression for most people
# value=("local" if torch.cuda.is_available() else "remote"),
value="remote",
label="Compute Mode",
info="NOTE: remote mode **does not** use ZeroGPU",
render=False,
)
loading_data = gr.Textbox(
"Please wait for the email database to be downloaded and extracted.",
max_lines=0,
label="Loading Data",
visible=False,
render=False,
)
downloading = gr.Textbox(
max_lines=1,
label="Downloading Data",
visible=False,
render=False,
)
extracting = gr.Textbox(
max_lines=1,
label="Extracting Data",
visible=False,
render=False,
)
missing_data = gr.Textbox(
value="Email database is missing. Try reloading this page. If the problem persists, please contact the maintainer.",
lines=1,
label="Error downloading or extracting data",
visible=False,
render=False,
)
chatbot = gr.Chatbot(
type="messages",
show_label=False,
avatar_images=(
None,
(
"images/cloud.png"
if compute_mode.value == "remote"
else "images/chip.png"
),
),
show_copy_all_button=True,
render=False,
)
# Modified from gradio/chat_interface.py
input = gr.Textbox(
show_label=False,
label="Message",
placeholder="Type a message...",
scale=7,
autofocus=True,
submit_btn=True,
render=False,
)
emails_textbox = gr.Textbox(
label="Retrieved Emails",
info="Tip: Look for 'Tool Call' and 'Next Email' separators. Quoted lines (starting with '>') are removed before indexing.",
lines=10,
visible=False,
render=False,
)
citations_textbox = gr.Textbox(
label="Citations",
lines=2,
visible=False,
render=False,
)
# ------------
# Set up state
# ------------
def generate_thread_id():
"""Generate a new thread ID"""
thread_id = uuid.uuid4()
# print(f"Generated thread_id: {thread_id}")
return thread_id
# Define thread_id variable
thread_id = gr.State(generate_thread_id())
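    # The thread_id is passed to the graph config to keep per-conversation chat memory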
# Define states for the output textboxes
retrieved_emails = gr.State([])
citations_text = gr.State([])
# ------------------
# Make the interface
# ------------------
def get_intro_text():
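        """Return the Markdown introduction shown above the chat"""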
intro = f"""<!-- # π€ R-help-chat -->
<!-- Get AI-powered answers about R programming backed by email retrieval. -->
## 🇷🤖💬 R-help-chat
**Chat with the [R-help mailing list archives](https://stat.ethz.ch/pipermail/r-help/).**
An LLM turns your question into a search query, including year ranges and months, and generates an answer from the retrieved emails.
You can ask follow-up questions with the chat history as context.
➡️ To clear the history and start a new chat, press the 🗑️ clear button.
**_Answers may be incorrect._**
"""
return intro
def get_status_text(compute_mode):
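        """Return Markdown describing the current compute mode"""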
if compute_mode == "remote":
status_text = f"""
📌 Now in **remote** mode, using the OpenAI API<br>
⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
✨ text-embedding-3-small and {openai_model}<br>
🔗 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
"""
if compute_mode == "local":
status_text = f"""
📌 Now in **local** mode, using ZeroGPU hardware<br>
⌛ Response time is about one minute<br>
✨ [{embedding_model_id.split("/")[-1]}](https://huggingface.co/{embedding_model_id}) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
🔗 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
"""
return status_text
def get_info_text():
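        """Return Markdown summarizing the email database and app features"""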
try:
# Get source files for each email and start and end months from database
sources = get_sources()
start, end = get_start_end_months(sources)
        except Exception:
# If database isn't ready, put in empty values
sources = []
start = None
end = None
info_text = f"""
**Database:** {len(sources)} emails from {start} to {end}.
**Features:** RAG, today's date, hybrid search (dense+sparse), multiple retrievals, citations output (remote), chat memory.
**Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
"""
return info_text
def get_example_questions(compute_mode, as_dataset=True):
"""Get example questions based on compute mode"""
questions = [
# "What is today's date?",
"Summarize emails from the most recent two months",
"Show me code examples using plotmath",
"When was has.HLC mentioned?",
"Who reported installation problems in 2023-2024?",
]
## Remove "/think" from questions in remote mode
# if compute_mode == "remote":
# questions = [q.replace(" /think", "") for q in questions]
# cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
def get_multi_tool_questions(compute_mode, as_dataset=True):
"""Get multi-tool example questions based on compute mode"""
questions = [
"Differences between lapply and for loops",
"Discuss pipe operator usage in 2022, 2023, and 2024",
]
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
def get_multi_turn_questions(compute_mode, as_dataset=True):
"""Get multi-turn example questions based on compute mode"""
questions = [
"Lookup emails that reference bugs.r-project.org in 2025",
"Did the authors you cited report bugs before 2025?",
]
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
with gr.Row():
# Left column: Intro, Compute, Chat
with gr.Column(scale=2):
with gr.Row(elem_classes=["row-container"]):
with gr.Column(scale=2):
intro = gr.Markdown(get_intro_text())
with gr.Column(scale=1):
compute_mode.render()
with gr.Group() as chat_interface:
chatbot.render()
input.render()
# Render textboxes for data loading progress
loading_data.render()
downloading.render()
extracting.render()
missing_data.render()
# Right column: Info, Examples
with gr.Column(scale=1):
status = gr.Markdown(get_status_text(compute_mode.value))
with gr.Accordion("βΉοΈ More Info", open=False):
info = gr.Markdown(get_info_text())
with gr.Accordion("π‘ Examples", open=True):
# Add some helpful examples
example_questions = gr.Examples(
examples=get_example_questions(
compute_mode.value, as_dataset=False
),
inputs=[input],
label="Click an example to fill the message box",
)
multi_tool_questions = gr.Examples(
examples=get_multi_tool_questions(
compute_mode.value, as_dataset=False
),
inputs=[input],
label="Multiple retrievals",
)
multi_turn_questions = gr.Examples(
examples=get_multi_turn_questions(
compute_mode.value, as_dataset=False
),
inputs=[input],
label="Asking follow-up questions",
)
# Bottom row: retrieved emails and citations
with gr.Row():
with gr.Column(scale=2):
emails_textbox.render()
with gr.Column(scale=1):
citations_textbox.render()
# -------------
# App functions
# -------------
def value(value):
"""Return updated value for a component"""
return gr.update(value=value)
def set_avatar(compute_mode):
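        """Return a chatbot update that sets the avatar image for the compute mode"""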
if compute_mode == "remote":
image_file = "images/cloud.png"
if compute_mode == "local":
image_file = "images/chip.png"
return gr.update(
avatar_images=(
None,
image_file,
),
)
def change_visibility(visible):
"""Return updated visibility state for a component"""
return gr.update(visible=visible)
def update_textbox(content, textbox):
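        """Return the new content and visibility for an output textbox"""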
if content is None:
# Keep the content of the textbox unchanged
return textbox, change_visibility(True)
elif content == []:
# Blank out the textbox
return "", change_visibility(False)
else:
# Display the content in the textbox
return content, change_visibility(True)
# --------------
# Event handlers
# --------------
# Start a new thread when the user presses the clear (trash) button
# https://github.com/gradio-app/gradio/issues/9722
chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)
def clear_component(component):
"""Return cleared component"""
return component.clear()
compute_mode.change(
# Start a new thread
generate_thread_id,
outputs=[thread_id],
api_name=False,
).then(
# Focus textbox by updating the textbox with the current value
lambda x: gr.update(value=x),
[input],
[input],
api_name=False,
).then(
# Change the app status text
get_status_text,
[compute_mode],
[status],
api_name=False,
).then(
# Clear the chatbot history
clear_component,
[chatbot],
[chatbot],
api_name=False,
).then(
# Change the chatbot avatar
set_avatar,
[compute_mode],
[chatbot],
api_name=False,
)
input.submit(
# Submit input to the chatbot
to_workflow,
[input, chatbot, compute_mode, thread_id],
[chatbot, retrieved_emails, citations_text],
api_name=False,
)
retrieved_emails.change(
# Update the emails textbox
update_textbox,
[retrieved_emails, emails_textbox],
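        # The same textbox is listed twice so update_textbox can set both its value and its visibility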
[emails_textbox, emails_textbox],
api_name=False,
)
citations_text.change(
# Update the citations textbox
update_textbox,
[citations_text, citations_textbox],
[citations_textbox, citations_textbox],
api_name=False,
)
chatbot.clear(
# Focus textbox when the chatbot is cleared
lambda x: gr.update(value=x),
[input],
[input],
api_name=False,
)
# Clean up graph instances when page is closed/refreshed
demo.unload(cleanup_graph)
if __name__ == "__main__":
# Set allowed_paths to serve chatbot avatar images
current_directory = os.getcwd()
allowed_paths = [current_directory + "/images"]
# Launch the Gradio app
demo.launch(allowed_paths=allowed_paths)