# AI-Teacher1 / app.py
# Last commit: fb0adfa "Update app.py" (Intellectualtech, verified)
import logging
import os
import re
from collections import deque
from typing import Iterator, List, Tuple

import gradio as gr
from huggingface_hub import InferenceClient
# Application-wide logging: timestamped, level-tagged records at INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger for this module; named after the module so records can be filtered.
logger = logging.getLogger(__name__)
# Initialize the InferenceClient shared by every chat request.
# The API token is read from HUGGINGFACEHUB_API_TOKEN. When it is unset, None is
# passed through and huggingface_hub applies its own token lookup
# (presumably a locally saved token / HF_TOKEN — verify for the deployment).
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not _hf_token:
    # The Llama-2 repos are gated on the Hub, so anonymous requests are likely
    # to be rejected at request time rather than here at construction time.
    logger.warning(
        "HUGGINGFACEHUB_API_TOKEN is not set; requests to gated models such as "
        "meta-llama/Llama-2-7b-chat-hf may be rejected"
    )
try:
    client = InferenceClient(
        model="meta-llama/Llama-2-7b-chat-hf",
        token=_hf_token,
    )
    logger.info("Successfully initialized InferenceClient")
except Exception:
    # logger.exception records the full traceback, not just str(e).
    logger.exception("Failed to initialize InferenceClient")
    raise
# Rolling store of {"query": ..., "response": ...} dicts used as retrieval
# context for later questions; once full, appending evicts the oldest pair.
# NOTE(review): this is a single process-global store, so interactions from
# every user/session are mixed together and can surface in other users'
# prompts — confirm this is acceptable for the deployment.
_MEMORY_CAPACITY = 100
MEMORY = deque(maxlen=_MEMORY_CAPACITY)
def add_to_memory(query: str, response: str):
    """Record one completed exchange in the module-level MEMORY deque.

    Args:
        query: The user's question.
        response: The model's complete answer to that question.

    MEMORY is created with ``maxlen``, so appending to a full store silently
    drops its oldest entry.
    """
    record = dict(query=query, response=response)
    MEMORY.append(record)
    logger.info("Added query-response pair to memory")
def find_relevant_context(
    query: str,
    max_contexts: int = 2,
    threshold: float = 0.3,
    memory=None,
) -> str:
    """Build a prompt snippet from the stored interactions most similar to *query*.

    Similarity is the fraction of the query's distinct lowercase word tokens
    (``\\w+`` runs) that also occur in a stored query. Entries scoring strictly
    above *threshold* are ranked best-first, with ties going to the most
    recently stored entry. The top *max_contexts* entries are then formatted.

    Args:
        query: The new user question.
        max_contexts: Maximum number of past interactions to include; values
            <= 0 produce an empty string.
        threshold: Exclusive minimum overlap for a stored query to count as
            relevant.
        memory: Iterable of ``{"query": str, "response": str}`` dicts to search.
            Defaults to the module-level ``MEMORY`` store.

    Returns:
        A ``"Relevant past interactions:"`` block terminated by a blank line, or
        ``""`` when nothing qualifies.
    """
    if memory is None:
        memory = MEMORY
    if max_contexts <= 0:
        return ""

    query_words = set(re.findall(r"\w+", query.lower()))
    denominator = max(len(query_words), 1)  # avoid division by zero for "???"

    scored = []
    for position, mem in enumerate(memory):
        past_query = mem["query"]
        # Skip malformed entries instead of crashing on .lower().
        if not isinstance(past_query, str):
            continue
        mem_words = set(re.findall(r"\w+", past_query.lower()))
        overlap = len(query_words & mem_words) / denominator
        if overlap > threshold:
            scored.append((overlap, position, mem))

    if not scored:
        return ""

    # Best overlap first; among equal scores the newest (highest position) wins.
    scored.sort(key=lambda item: (item[0], item[1]), reverse=True)
    relevant = [mem for _, _, mem in scored[:max_contexts]]
    context = "\n".join(
        f"Past Query: {mem['query']}\nPast Response: {mem['response']}"
        for mem in relevant
    )
    return f"Relevant past interactions:\n{context}\n\n"
def _history_to_messages(history) -> List[dict]:
    """Convert Gradio chat history into chat-completion message dicts.

    Accepts both the "tuples" format ``[(user, assistant), ...]`` and the
    "messages" format ``[{"role": ..., "content": ...}, ...]``. Empty turns
    and non-text content are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            # Only plain-text user/assistant turns can be forwarded to the model.
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def respond(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """
    Stream an educational response, using similar past interactions as context.

    Args:
        message (str): The student's input question or query.
        history (List[Tuple[str, str]]): Chat history of (student, teacher)
            pairs; Gradio "messages"-style role/content dicts are also accepted.
        system_message (str): The system prompt defining the AI teacher's behavior.
        max_tokens (int): Maximum number of tokens to generate (1-2048).
        temperature (float): Sampling randomness (0.1-2.0).
        top_p (float): Nucleus-sampling cutoff (0.1-1.0).
    Yields:
        str: The accumulated response text so far. Each yield is the full text,
        not a single token. If the completion call fails, an error message is
        yielded instead.
    Raises:
        ValueError: If the message is blank or a parameter is out of range.
            Because this is a generator, the error surfaces on first iteration.
    """
    # Validate input parameters
    if not message.strip():
        raise ValueError("Input message cannot be empty")
    if max_tokens < 1 or max_tokens > 2048:
        raise ValueError("max_tokens must be between 1 and 2048")
    if temperature < 0.1 or temperature > 2.0:
        raise ValueError("temperature must be between 0.1 and 2.0")
    if top_p < 0.1 or top_p > 1.0:
        raise ValueError("top_p must be between 0.1 and 1.0")

    # Retrieve relevant past interactions.
    # NOTE(review): the memory store is process-global, so this context may
    # come from other users' sessions.
    context = find_relevant_context(message)

    # Only attach the retrieval instructions when there is something to use.
    system_content = system_message
    if context:
        system_content += (
            "\n\nUse the following past interactions to inform your response if relevant:\n"
            + context
        )

    messages = [{"role": "system", "content": system_content}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        # Named `chunk` so the user's `message` is not overwritten before it is
        # stored in memory below.
        for chunk in stream:
            if not chunk.choices:  # e.g. trailing usage-only chunks
                continue
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
    except Exception as e:
        error_msg = f"Error during chat completion: {str(e)}"
        logger.exception(error_msg)
        yield error_msg  # Yield the error message to display in Gradio
        return

    # Store the query and final response so later questions can reuse them.
    if response:
        add_to_memory(message, response)
def main():
    """
    Build the AI Teacher chat UI and serve it on 0.0.0.0:7860.

    respond() is the chat handler. The four extra controls built here are
    passed to it, in order, as system_message, max_tokens, temperature and
    top_p. A failure to launch is logged and re-raised.
    """
    teacher_prompt = (
        "You are an AI Teacher, a knowledgeable and patient educator dedicated to helping students and learners. "
        "Your goal is to explain concepts clearly, provide step-by-step guidance, and encourage critical thinking. "
        "Adapt your explanations to the learner's level, ask follow-up questions to deepen understanding, and provide examples where helpful. "
        "Be supportive, professional, and engaging in all interactions."
    )

    # Extra inputs, in the same order as respond()'s trailing parameters.
    prompt_box = gr.Textbox(
        value=teacher_prompt,
        label="AI Teacher Prompt",
        lines=3,
        placeholder="Customize the AI Teacher's teaching style or instructions",
    )
    length_slider = gr.Slider(
        minimum=1, maximum=2048, value=512, step=1, label="Maximum Response Length"
    )
    creativity_slider = gr.Slider(
        minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Response Creativity"
    )
    diversity_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Response Diversity"
    )

    intro_text = (
        "Welcome to AI Teacher, your personal guide for learning and studying! "
        "Ask questions about any subject, and I'll provide clear explanations, examples, and tips to help you succeed. "
        "Adjust the settings to customize how I respond to your questions."
    )
    page_css = (
        "\n"
        ".gradio-container { max-width: 900px; margin: auto; padding: 20px; }\n"
        ".chatbot { border-radius: 12px; background-color: #f9fafb; }\n"
        "h1 { color: #2b6cb0; }\n"
        ".message { font-size: 16px; }\n"
    )

    demo = gr.ChatInterface(
        fn=respond,
        additional_inputs=[prompt_box, length_slider, creativity_slider, diversity_slider],
        title="AI Teacher: Your Study Companion",
        description=intro_text,
        theme="soft",
        css=page_css,
    )

    try:
        logger.info("Launching Gradio interface for AI Teacher")
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        raise


if __name__ == "__main__":
    main()