Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| import torch | |
| from huggingface_hub import login | |
| from sentence_transformers import SentenceTransformer, util | |
| from transformers import pipeline, TextIteratorStreamer | |
| from threading import Thread | |
| from config import ( | |
| HF_TOKEN, EMBEDDING_MODEL_ID, LLM_MODEL_ID, DEFAULT_MESSAGE_NO_MATCH, get_all_game_data, | |
| BASE_SIMILARITY_THRESHOLD, FOLLOWUP_SIMILARITY_THRESHOLD, | |
| silksong_theme, silksong_css, | |
| ) | |
| class ChatContext: | |
| """Holds the conversational state, including the current context and thresholds.""" | |
| def __init__(self): | |
| self.context_index = -1 | |
| self.base_similarity = BASE_SIMILARITY_THRESHOLD | |
| self.followup_similarity = FOLLOWUP_SIMILARITY_THRESHOLD | |
| print("Logging into Hugging Face Hub...") | |
| login(token=HF_TOKEN) | |
| print("Initializing embedding model...") | |
| embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID) | |
| print("Initializing language model...") | |
| llm_pipeline = pipeline( | |
| "text-generation", | |
| model=LLM_MODEL_ID, | |
| device_map="auto", | |
| dtype="auto", | |
| ) | |
| knowledge_base = get_all_game_data(embedding_model) | |
def _select_content(title: str) -> list[dict]:
    """Look up the knowledge-base documents for *title*.

    Returns an empty list when the title is unknown, so callers never have
    to handle a missing key themselves.
    """
    if title in knowledge_base:
        return knowledge_base[title]
    return []
def find_best_context(query: str, contents: list[dict], similarity_threshold: float) -> int:
    """Finds the most relevant document index based on semantic similarity.

    Embeds *query*, scores it against each document's pre-computed embedding
    with cosine similarity, and returns the index of the best match — or -1
    when the query/contents are empty, the embeddings cannot be stacked, or
    no score reaches *similarity_threshold*.
    """
    if not query or not contents:
        return -1

    device = embedding_model.device
    query_vec = embedding_model.encode(query, prompt_name="query", convert_to_tensor=True).to(device)

    try:
        # Stack pre-computed tensors from our knowledge base
        doc_matrix = torch.stack([doc["embedding"] for doc in contents]).to(device)
    except (RuntimeError, IndexError, TypeError) as e:
        print(f"Warning: Could not stack content embeddings. Error: {e}")
        return -1

    # Compute cosine similarity between the 1 query embedding and N content embeddings
    scores = util.pytorch_cos_sim(query_vec, doc_matrix)
    if scores.numel() == 0:
        print("Warning: Similarity computation returned an empty tensor.")
        return -1

    # Locate the single best match and its score
    top_index = scores.argmax().item()
    top_score = scores[0, top_index].item()
    print(f"Best score: {top_score:.4f} (Threshold: {similarity_threshold})")

    if top_score < similarity_threshold:
        print("No context met the similarity threshold.")
        return -1

    print(f"Using \"{contents[top_index]['metadata']['source']}\"...")
    return top_index
def respond(message: str, history: list, title: str, chat_context: ChatContext):
    """Generates a streaming response from the LLM based on the best context found.

    Args:
        message: The user's current question.
        history: Prior turns as a list of {"role", "content"} dicts
            (Gradio "messages" format).
        title: Selected game title; chooses the knowledge base.
        chat_context: Mutable state carrying the active context index and
            the two similarity thresholds.

    Yields:
        (partial_response, chat_context) tuples so Gradio can stream the
        answer while persisting the updated state.
    """
    default_threshold = chat_context.base_similarity
    followup_threshold = chat_context.followup_similarity
    contents = _select_content(title)
    if not contents:
        print(f"No content found for {title}")
        chat_context.context_index = -1  # Reset: no knowledge base for this title
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return
    if len(history) == 0:
        # Clear context on a new conversation
        print("New conversation started. Clearing context.")
        chat_context.context_index = -1
    # Determine threshold: Use follow-up ONLY if we have a valid previous context.
    similarity_threshold = followup_threshold if chat_context.context_index != -1 else default_threshold
    print(f"Using {'follow-up' if chat_context.context_index != -1 else 'default'} threshold: {similarity_threshold}")
    # Find the best new context based on the current message
    found_context_index = find_best_context(message, contents, similarity_threshold)
    if found_context_index >= 0:
        chat_context.context_index = found_context_index  # A new, relevant context was found and set
    elif chat_context.context_index >= 0:
        # PASS: A follow-up question, but no new context. Reuse the old one.
        print("No new context found, reusing previous context for follow-up.")
    else:
        # FAILURE: No new context was found AND no previous context exists.
        print("No context found and no previous context. Yielding no match.")
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return
    system_prompt = f"Answer the following QUESTION based only on the CONTEXT provided. If the answer cannot be found in the CONTEXT, write \"{DEFAULT_MESSAGE_NO_MATCH}\"\n---\nCONTEXT:\n{contents[chat_context.context_index]['text']}\n"
    user_prompt = f"QUESTION:\n{message}"
    messages = [{"role": "system", "content": system_prompt}]
    # Add previous turns (history) after the system prompt but before the current question
    messages.extend(history)
    messages.append({"role": "user", "content": user_prompt})
    # Debug print the conversation being sent (excluding the large system prompt)
    for item in messages[1:]:
        print(f"[{item['role']}] {item['content']}")
    # Run generation on a worker thread; TextIteratorStreamer hands tokens
    # back to this generator as they are produced.
    streamer = TextIteratorStreamer(llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=llm_pipeline,
        kwargs=dict(
            text_inputs=messages,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
        )
    )
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        # Yield the partial response AND the current state
        yield response, chat_context
    # Fix: join the worker so it cannot outlive this generator; the streamer
    # is exhausted at this point, so the join returns promptly.
    thread.join()
    print(f"[assistant] {response}")
| # --- GRADIO UI --- | |
| # Defines the web interface for the chatbot. | |
def on_title_changed(context_state: ChatContext) -> tuple[str, ChatContext]:
    """Reset the stored context and its display when the game selection changes."""
    context_state.context_index = -1
    empty_display = """<div class="context">Context: None</div>"""
    return empty_display, context_state
def on_sim_changed(context_state: ChatContext, base_sim: float, followup_sim: float) -> ChatContext:
    """Copy the slider values into the state's similarity thresholds and return it."""
    context_state.base_similarity, context_state.followup_similarity = base_sim, followup_sim
    return context_state
| gr.set_static_paths(paths=["assets/"]) | |
| with gr.Blocks(theme=silksong_theme, css=silksong_css) as demo: | |
| def on_context_changed(context_state: ChatContext, title: str) -> str: | |
| """Updates the HTML context display when the context_index state changes.""" | |
| context_index = context_state.context_index | |
| if context_index < 0: | |
| return """<div class="context">Context: None</div>""" | |
| contents = _select_content(title) | |
| if not contents or context_index >= len(contents): | |
| return """<div class="context">Context: Error</div>""" | |
| url = contents[context_index]['metadata']['source'] | |
| title = contents[context_index]['metadata']['title'] | |
| return f"""<div class="context">Context: <a href="{url}" target="_blank">{title}</a></div>""" | |
| gr.HTML(""" | |
| <div class="header-text"> | |
| <h1>A Weaver's Counsel</h1> | |
| <p>Speak, little traveler. What secrets of Pharloom do you seek?</p> | |
| <p style="font-style: italic;">(Note: This bot has a limited knowledge.)</p> | |
| </div> | |
| """) | |
| game_title = gr.Dropdown(["Hollow Knight", "Silksong"], label="Game", value="Silksong") | |
| output = gr.HTML("""<div class="context">Context: None</div>""") | |
| # Link the state object to the UI elements | |
| context_state = gr.State(ChatContext()) | |
| context_state.change(on_context_changed, [context_state, game_title], output) | |
| game_title.change(on_title_changed, context_state, [output, context_state]) | |
| gr.ChatInterface( | |
| respond, | |
| type="messages", | |
| chatbot=gr.Chatbot(type="messages", label=LLM_MODEL_ID), | |
| textbox=gr.Textbox(placeholder="Ask about the haunted kingdom...", container=False, submit_btn=True, scale=7), | |
| additional_inputs=[ | |
| game_title, | |
| context_state, ### Pass the state object as an input | |
| ], | |
| additional_outputs=[context_state], ### Receive the updated state as an output | |
| examples=[ | |
| ["Where can I find the Moorwing?", "Silksong"], | |
| ["Who is the voice of Lace?", "Silksong"], | |
| ["How can I beat the False Knight?", "Hollow Knight"], | |
| ["Any achievement for Hornet Protector?", "Hollow Knight"], | |
| ], | |
| cache_examples=False, | |
| ) | |
| base_sim = gr.Slider(minimum=0.1, maximum=1.0, value=BASE_SIMILARITY_THRESHOLD, step=0.1, label="Base Similarity Threshold") | |
| followup_sim = gr.Slider(minimum=0.1, maximum=1.0, value=FOLLOWUP_SIMILARITY_THRESHOLD, step=0.1, label="Similarity Threshold with follow-up questions (multi-turn)") | |
| base_sim.release(on_sim_changed, [context_state, base_sim, followup_sim], context_state) | |
| followup_sim.release(on_sim_changed, [context_state, base_sim, followup_sim], context_state) | |
| gr.HTML(""" | |
| <div class="disclaimer"> | |
| <p><strong>Disclaimer:</strong></p> | |
| <ul style="list-style: none; padding: 0;"> | |
| <li>This is a fan-made personal demonstration and not affiliated with any organization.<br>The bot is for entertainment purposes only.</li> | |
| <li>Factual information is sourced from the <a href="https://hollowknight.wiki" target="_blank">Hollow Knight Wiki</a>.<br>Content is available under <a href="https://creativecommons.org/licenses/by-sa/3.0/" target="_blank">Commons Attribution-ShareAlike</a> unless otherwise noted.</li> | |
| <li>Built by <a href="https://huggingface.co/bebechien" target="_blank">bebechien</a> with a 💖 for the world of Hollow Knight.</li> | |
| </ul> | |
| </div> | |
| """) | |
| if __name__ == "__main__": | |
| print("Launching Gradio demo...") | |
| demo.launch() | |