# bebechien's picture
# Upload folder using huggingface_hub
# 9a8c5bf verified
import gradio as gr
import spaces
import torch
from huggingface_hub import login
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, TextIteratorStreamer
from threading import Thread
from config import (
HF_TOKEN, EMBEDDING_MODEL_ID, LLM_MODEL_ID, DEFAULT_MESSAGE_NO_MATCH, get_all_game_data,
BASE_SIMILARITY_THRESHOLD, FOLLOWUP_SIMILARITY_THRESHOLD,
silksong_theme, silksong_css,
)
class ChatContext:
    """Holds the conversational state, including the current context and thresholds.

    Attributes:
        context_index: Index into the knowledge base of the document currently
            used as grounding context, or -1 when no context is active.
        base_similarity: Similarity threshold applied to the first question of
            a conversation.
        followup_similarity: Threshold applied to follow-up questions once a
            previous context already exists.
    """

    def __init__(self) -> None:
        # -1 is the sentinel for "no context selected yet".
        self.context_index: int = -1
        # Seeded from config defaults; the UI sliders may overwrite these
        # later via the on_sim_changed callback.
        self.base_similarity: float = BASE_SIMILARITY_THRESHOLD
        self.followup_similarity: float = FOLLOWUP_SIMILARITY_THRESHOLD

    def __repr__(self) -> str:
        # Compact repr so the state object is readable in debug prints/logs.
        return (
            f"{type(self).__name__}(context_index={self.context_index}, "
            f"base_similarity={self.base_similarity}, "
            f"followup_similarity={self.followup_similarity})"
        )
# --- One-time initialization (runs at import time) ---
print("Logging into Hugging Face Hub...")
# Authenticate so gated models can be downloaded.
login(token=HF_TOKEN)
print("Initializing embedding model...")
# Sentence-transformers model used to embed queries and compare them
# against pre-computed document embeddings.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)
print("Initializing language model...")
# Text-generation pipeline; "auto" lets transformers pick device placement
# and precision based on the available hardware.
llm_pipeline = pipeline(
    "text-generation",
    model=LLM_MODEL_ID,
    device_map="auto",
    dtype="auto",
)
# Knowledge base: maps a game title to a list of document dicts (each with
# "text", "metadata", and a pre-computed "embedding" tensor — see usage in
# _select_content / find_best_context / respond below).
knowledge_base = get_all_game_data(embedding_model)
def _select_content(title: str) -> list[dict]:
    """Return the knowledge-base documents for *title*; unknown titles yield []."""
    try:
        return knowledge_base[title]
    except KeyError:
        return []
@torch.no_grad()
def find_best_context(query: str, contents: list[dict], similarity_threshold: float) -> int:
    """Return the index of the most similar document in *contents*, or -1.

    Encodes *query* with the shared embedding model, computes cosine
    similarity against the pre-computed document embeddings, and accepts the
    top match only when its score reaches *similarity_threshold*.
    """
    # Nothing to search against.
    if not query or not contents:
        return -1

    device = embedding_model.device
    query_embedding = embedding_model.encode(
        query, prompt_name="query", convert_to_tensor=True
    ).to(device)

    try:
        # Stack the per-document embedding tensors into a single (N, D) matrix.
        doc_matrix = torch.stack([doc["embedding"] for doc in contents]).to(device)
    except (RuntimeError, IndexError, TypeError) as e:
        print(f"Warning: Could not stack content embeddings. Error: {e}")
        return -1

    # (1, N) similarity scores: one query row vs. N documents.
    scores = util.pytorch_cos_sim(query_embedding, doc_matrix)
    if scores.numel() == 0:
        print("Warning: Similarity computation returned an empty tensor.")
        return -1

    # argmax over a (1, N) tensor: the flat index equals the column index.
    winner = scores.argmax().item()
    winner_score = scores[0, winner].item()
    print(f"Best score: {winner_score:.4f} (Threshold: {similarity_threshold})")

    if winner_score < similarity_threshold:
        print("No context met the similarity threshold.")
        return -1

    print(f"Using \"{contents[winner]['metadata']['source']}\"...")
    return winner
@spaces.GPU
def respond(message: str, history: list, title: str, chat_context: ChatContext):
    """Generates a streaming response from the LLM based on the best context found.

    Args:
        message: The user's current question.
        history: Prior turns as a list of {"role", "content"} dicts
            (Gradio "messages" format — assumed from ChatInterface usage).
        title: Selected game title; keys into the knowledge base.
        chat_context: Mutable per-session state (context index + thresholds).

    Yields:
        (partial_response_text, chat_context) tuples so Gradio can stream the
        reply while keeping the state object in sync.
    """
    default_threshold = chat_context.base_similarity
    followup_threshold = chat_context.followup_similarity
    contents = _select_content(title)
    if not contents:
        print(f"No content found for {title}")
        chat_context.context_index = -1  # Reset context: no documents for this title
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return
    if len(history) == 0:
        # Clear context on a new conversation
        print("New conversation started. Clearing context.")
        chat_context.context_index = -1
    # Determine threshold: Use follow-up ONLY if we have a valid previous context.
    similarity_threshold = followup_threshold if chat_context.context_index != -1 else default_threshold
    print(f"Using {'follow-up' if chat_context.context_index != -1 else 'default'} threshold: {similarity_threshold}")
    # Find the best new context based on the current message
    found_context_index = find_best_context(message, contents, similarity_threshold)
    if found_context_index >= 0:
        chat_context.context_index = found_context_index  # A new, relevant context was found and set
    elif chat_context.context_index >= 0:
        # PASS: A follow-up question, but no new context. Reuse the old one.
        print("No new context found, reusing previous context for follow-up.")
    else:
        # FAILURE: No new context was found AND no previous context exists.
        print("No context found and no previous context. Yielding no match.")
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return
    # Grounding prompt: instruct the model to answer only from the selected document.
    system_prompt = f"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"{DEFAULT_MESSAGE_NO_MATCH}\"\n---\nCONTEXT:\n{contents[chat_context.context_index]['text']}\n"
    user_prompt = f"QUESTION:\n{message}"
    messages = [{"role": "system", "content": system_prompt}]
    # Add previous turns (history) after the system prompt but before the current question
    messages.extend(history)
    messages.append({"role": "user", "content": user_prompt})
    # Debug print the conversation being sent (excluding the large system prompt)
    for item in messages[1:]:
        print(f"[{item['role']}] {item['content']}")
    # Run generation in a background thread and consume the streamer iterator
    # here, so partial text can be yielded to the UI as tokens arrive.
    streamer = TextIteratorStreamer(llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=llm_pipeline,
        kwargs=dict(
            text_inputs=messages,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
        )
    )
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        # Yield the cumulative partial response AND the current state
        yield response, chat_context
    print(f"[assistant] {response}")
# --- GRADIO UI ---
# Defines the web interface for the chatbot.
# NOTE: removed a stray @staticmethod decorator — this is a module-level
# function (likely a leftover from a class-based refactor); before
# Python 3.10 a bare staticmethod object is not directly callable.
def on_title_changed(context_state: "ChatContext") -> "tuple[str, ChatContext]":
    """Resets the context display and state when the game is changed.

    Args:
        context_state: The per-session ChatContext to reset.

    Returns:
        The "Context: None" HTML snippet and the (mutated) state object.
    """
    context_state.context_index = -1
    return """<div class="context">Context: None</div>""", context_state
# NOTE: removed a stray @staticmethod decorator — this is a module-level
# function (likely a leftover from a class-based refactor); before
# Python 3.10 a bare staticmethod object is not directly callable.
def on_sim_changed(context_state: "ChatContext", base_sim: float, followup_sim: float) -> "ChatContext":
    """Updates the similarity thresholds in the context state.

    Args:
        context_state: The per-session ChatContext to update.
        base_sim: New threshold for first questions.
        followup_sim: New threshold for follow-up questions.

    Returns:
        The same (mutated) state object, for Gradio to store back.
    """
    context_state.base_similarity = base_sim
    context_state.followup_similarity = followup_sim
    return context_state
# Expose the assets/ directory so the custom theme/CSS can reference files in it.
gr.set_static_paths(paths=["assets/"])
with gr.Blocks(theme=silksong_theme, css=silksong_css) as demo:
    def on_context_changed(context_state: ChatContext, title: str) -> str:
        """Updates the HTML context display when the context_index state changes."""
        context_index = context_state.context_index
        if context_index < 0:
            return """<div class="context">Context: None</div>"""
        contents = _select_content(title)
        # Guard against a stale index (e.g. kept from a different game's list).
        if not contents or context_index >= len(contents):
            return """<div class="context">Context: Error</div>"""
        url = contents[context_index]['metadata']['source']
        title = contents[context_index]['metadata']['title']
        return f"""<div class="context">Context: <a href="{url}" target="_blank">{title}</a></div>"""

    # Page header.
    gr.HTML("""
    <div class="header-text">
    <h1>A Weaver's Counsel</h1>
    <p>Speak, little traveler. What secrets of Pharloom do you seek?</p>
    <p style="font-style: italic;">(Note: This bot has a limited knowledge.)</p>
    </div>
    """)
    game_title = gr.Dropdown(["Hollow Knight", "Silksong"], label="Game", value="Silksong")
    # Shows which wiki article is currently used as grounding context.
    output = gr.HTML("""<div class="context">Context: None</div>""")
    # Link the state object to the UI elements
    context_state = gr.State(ChatContext())
    context_state.change(on_context_changed, [context_state, game_title], output)
    game_title.change(on_title_changed, context_state, [output, context_state])
    gr.ChatInterface(
        respond,
        type="messages",
        chatbot=gr.Chatbot(type="messages", label=LLM_MODEL_ID),
        textbox=gr.Textbox(placeholder="Ask about the haunted kingdom...", container=False, submit_btn=True, scale=7),
        additional_inputs=[
            game_title,
            context_state,  ### Pass the state object as an input
        ],
        additional_outputs=[context_state],  ### Receive the updated state as an output
        examples=[
            ["Where can I find the Moorwing?", "Silksong"],
            ["Who is the voice of Lace?", "Silksong"],
            ["How can I beat the False Knight?", "Hollow Knight"],
            ["Any achievement for Hornet Protector?", "Hollow Knight"],
        ],
        cache_examples=False,
    )
    # Sliders feed new threshold values into the state via on_sim_changed
    # whenever the user releases the slider handle.
    base_sim = gr.Slider(minimum=0.1, maximum=1.0, value=BASE_SIMILARITY_THRESHOLD, step=0.1, label="Base Similarity Threshold")
    followup_sim = gr.Slider(minimum=0.1, maximum=1.0, value=FOLLOWUP_SIMILARITY_THRESHOLD, step=0.1, label="Similarity Threshold with follow-up questions (multi-turn)")
    base_sim.release(on_sim_changed, [context_state, base_sim, followup_sim], context_state)
    followup_sim.release(on_sim_changed, [context_state, base_sim, followup_sim], context_state)
    # Footer / legal notes.
    gr.HTML("""
    <div class="disclaimer">
    <p><strong>Disclaimer:</strong></p>
    <ul style="list-style: none; padding: 0;">
    <li>This is a fan-made personal demonstration and not affiliated with any organization.<br>The bot is for entertainment purposes only.</li>
    <li>Factual information is sourced from the <a href="https://hollowknight.wiki" target="_blank">Hollow Knight Wiki</a>.<br>Content is available under <a href="https://creativecommons.org/licenses/by-sa/3.0/" target="_blank">Commons Attribution-ShareAlike</a> unless otherwise noted.</li>
    <li>Built by <a href="https://huggingface.co/bebechien" target="_blank">bebechien</a> with a 💖 for the world of Hollow Knight.</li>
    </ul>
    </div>
    """)
# Script entry point: launch the Gradio app only when run directly
# (not when imported, e.g. by the Spaces runtime).
if __name__ == "__main__":
    print("Launching Gradio demo...")
    demo.launch()