# Hugging Face Spaces page header captured in the source dump ("Spaces: / Sleeping");
# kept as a comment so the file remains valid Python.
import gradio as gr | |
from huggingface_hub import InferenceClient | |
from typing import List, Tuple | |
import logging | |
from collections import deque | |
import re | |
import os | |
# Configure module-wide logging: timestamped messages at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Initialize the InferenceClient with API token.
# NOTE(review): reads HUGGINGFACEHUB_API_TOKEN from the environment; if unset,
# os.getenv returns None and the client is created without a token — confirm
# the deployment sets this variable.
try:
    client = InferenceClient(
        model="meta-llama/Llama-2-7b-chat-hf",  # Updated to the requested model
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
    )
    logger.info("Successfully initialized InferenceClient")
except Exception as e:
    # Log and re-raise: the app cannot function without a working client,
    # so startup should fail loudly rather than limp along.
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    raise
# Memory storage for learning from past queries.
# A bounded deque evicts the oldest entry automatically once full;
# each entry is a dict of the form {"query": str, "response": str}.
MEMORY = deque(maxlen=100)  # Store up to 100 query-response pairs
def add_to_memory(query: str, response: str):
    """Record a completed query/response exchange for later context retrieval."""
    entry = {"query": query, "response": response}
    MEMORY.append(entry)
    logger.info("Added query-response pair to memory")
def find_relevant_context(query: str, max_contexts: int = 2, threshold: float = 0.3) -> str:
    """Retrieve up to *max_contexts* past interactions relevant to *query*.

    Relevance is a simple bag-of-words overlap score: the fraction of the
    query's distinct words that also appear in a stored query. Entries whose
    score exceeds *threshold* are included, oldest first.

    Args:
        query (str): The incoming student question.
        max_contexts (int): Maximum number of past interactions to include.
        threshold (float): Minimum word-overlap ratio (0.0-1.0) for an entry
            to count as relevant. Defaults to the original hard-coded 0.3.

    Returns:
        str: A formatted context preamble, or "" when nothing relevant exists.
    """
    query_words = set(re.findall(r'\w+', query.lower()))
    # Hoisted out of the loop: the denominator never changes per entry.
    # max(..., 1) guards against division by zero for an all-punctuation query.
    denominator = max(len(query_words), 1)
    relevant = []
    for mem in MEMORY:
        mem_words = set(re.findall(r'\w+', mem["query"].lower()))
        if len(query_words & mem_words) / denominator > threshold:
            relevant.append(mem)
            if len(relevant) >= max_contexts:
                break
    if not relevant:
        return ""
    context = "\n".join(
        f"Past Query: {mem['query']}\nPast Response: {mem['response']}" for mem in relevant
    )
    return f"Relevant past interactions:\n{context}\n\n"
def respond(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """
    Generates an educational response using past interactions for context.

    Args:
        message (str): The student's input question or query.
        history (List[Tuple[str, str]]): Chat history with student and AI teacher messages.
        system_message (str): The system prompt defining the AI teacher's behavior.
        max_tokens (int): Maximum number of tokens to generate (1-2048).
        temperature (float): Controls randomness in response generation (0.1-2.0).
        top_p (float): Controls diversity via nucleus sampling (0.1-1.0).

    Yields:
        str: The accumulated AI teacher response, streamed token by token;
        on failure, a single error-message string is yielded instead.

    Raises:
        ValueError: If the message is empty or any parameter is out of range.
    """
    # Validate input parameters up front so bad values fail fast,
    # before any network call is made.
    if not message.strip():
        raise ValueError("Input message cannot be empty")
    if max_tokens < 1 or max_tokens > 2048:
        raise ValueError("max_tokens must be between 1 and 2048")
    if temperature < 0.1 or temperature > 2.0:
        raise ValueError("temperature must be between 0.1 and 2.0")
    if top_p < 0.1 or top_p > 1.0:
        raise ValueError("top_p must be between 0.1 and 1.0")
    # Retrieve relevant past interactions for few-shot-style grounding.
    context = find_relevant_context(message)
    # Construct the message history with memory context prepended to the system prompt.
    messages = [
        {
            "role": "system",
            "content": system_message + "\n\nUse the following past interactions to inform your response if relevant:\n" + context,
        }
    ]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    response = ""
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        # BUG FIX: the loop variable was previously named `message`, shadowing the
        # user's query parameter; add_to_memory() then stored the final stream-chunk
        # object instead of the query text. A distinct name preserves the parameter.
        for chunk in stream:
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
        # Store the user's query and the final accumulated response in memory.
        add_to_memory(message, response)
    except Exception as e:
        error_msg = f"Error during chat completion: {str(e)}"
        logger.error(error_msg)
        yield error_msg  # Yield the error message to display in Gradio
def main():
    """Build and launch the Gradio ChatInterface for the AI Teacher chatbot."""
    default_system_message = (
        "You are an AI Teacher, a knowledgeable and patient educator dedicated to helping students and learners. "
        "Your goal is to explain concepts clearly, provide step-by-step guidance, and encourage critical thinking. "
        "Adapt your explanations to the learner's level, ask follow-up questions to deepen understanding, and provide examples where helpful. "
        "Be supportive, professional, and engaging in all interactions."
    )
    # Tuning controls shown below the chat, in the exact order `respond` expects
    # its extra arguments: system prompt, max_tokens, temperature, top_p.
    prompt_box = gr.Textbox(
        value=default_system_message,
        label="AI Teacher Prompt",
        lines=3,
        placeholder="Customize the AI Teacher's teaching style or instructions",
    )
    length_slider = gr.Slider(
        minimum=1,
        maximum=2048,
        value=512,
        step=1,
        label="Maximum Response Length",
    )
    creativity_slider = gr.Slider(
        minimum=0.1,
        maximum=2.0,
        value=0.7,
        step=0.1,
        label="Response Creativity",
    )
    diversity_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Response Diversity",
    )
    demo = gr.ChatInterface(
        fn=respond,
        additional_inputs=[prompt_box, length_slider, creativity_slider, diversity_slider],
        title="AI Teacher: Your Study Companion",
        description=(
            "Welcome to AI Teacher, your personal guide for learning and studying! "
            "Ask questions about any subject, and I'll provide clear explanations, examples, and tips to help you succeed. "
            "Adjust the settings to customize how I respond to your questions."
        ),
        theme="soft",
        css="""
        .gradio-container { max-width: 900px; margin: auto; padding: 20px; }
        .chatbot { border-radius: 12px; background-color: #f9fafb; }
        h1 { color: #2b6cb0; }
        .message { font-size: 16px; }
        """,
    )
    try:
        logger.info("Launching Gradio interface for AI Teacher")
        # Bind to all interfaces on the standard Spaces port.
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        raise


if __name__ == "__main__":
    main()