import gradio as gr
from huggingface_hub import InferenceClient
from typing import Iterator, List, Tuple
import logging
from collections import deque
import re
import os
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Initialize the InferenceClient with an API token
try:
    client = InferenceClient(
        model="meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    )
    logger.info("Successfully initialized InferenceClient")
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    raise
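
# NOTE: the client needs a Hugging Face access token. One way to provide it
# (assuming you run this locally) is to export it before launching:
#   export HUGGINGFACEHUB_API_TOKEN=hf_...
# On a Space, the same value can be stored as a repository secret with that name.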
# Memory storage for learning from past queries
MEMORY = deque(maxlen=100) # Store up to 100 query-response pairs
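# Each stored entry is a plain dict; an illustrative (hypothetical) example:
#   {"query": "What is a derivative?", "response": "A derivative measures ..."}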

def add_to_memory(query: str, response: str):
    """Add a query-response pair to memory."""
    MEMORY.append({"query": query, "response": response})
    logger.info("Added query-response pair to memory")

def find_relevant_context(query: str, max_contexts: int = 2) -> str:
    """Retrieve relevant past queries and responses based on simple keyword matching."""
    query_words = set(re.findall(r'\w+', query.lower()))
    relevant = []
    for mem in MEMORY:
        mem_words = set(re.findall(r'\w+', mem["query"].lower()))
        # Fraction of the current query's words that also appear in the past query
        overlap = len(query_words & mem_words) / max(len(query_words), 1)
        if overlap > 0.3:  # Threshold for relevance
            relevant.append(mem)
            if len(relevant) >= max_contexts:
                break
    if relevant:
        context = "\n".join(
            [f"Past Query: {mem['query']}\nPast Response: {mem['response']}" for mem in relevant]
        )
        return f"Relevant past interactions:\n{context}\n\n"
    return ""
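# Worked example (hypothetical): if MEMORY holds a past query
# "what is photosynthesis" and the new query is "explain photosynthesis simply",
# the shared word set is {"photosynthesis"}, so overlap = 1/3 ≈ 0.33 > 0.3
# and that past pair would be included in the returned context.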

def respond(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
"""
Generates an educational response using past interactions for context.
Args:
message (str): The student's input question or query.
history (List[Tuple[str, str]]): Chat history with student and AI teacher messages.
system_message (str): The system prompt defining the AI teacher's behavior.
max_tokens (int): Maximum number of tokens to generate.
temperature (float): Controls randomness in response generation.
top_p (float): Controls diversity via nucleus sampling.
Yields:
str: The AI teacher's response, streamed token by token.
"""
    # Validate input parameters
    if not message.strip():
        raise ValueError("Input message cannot be empty")
    if max_tokens < 1 or max_tokens > 2048:
        raise ValueError("max_tokens must be between 1 and 2048")
    if temperature < 0.1 or temperature > 2.0:
        raise ValueError("temperature must be between 0.1 and 2.0")
    if top_p < 0.1 or top_p > 1.0:
        raise ValueError("top_p must be between 0.1 and 1.0")
    # Retrieve relevant past interactions
    context = find_relevant_context(message)
    # Construct the message history, injecting memory context only when present
    system_content = system_message
    if context:
        system_content += (
            "\n\nUse the following past interactions to inform your response if relevant:\n"
            + context
        )
    messages = [{"role": "system", "content": system_content}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
response = ""
try:
stream = client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
)
for message in stream:
token = message.choices[0].delta.content or ""
response += token
yield response
# Store the query and final response in memory
add_to_memory(message, response)
except Exception as e:
error_msg = f"Error during chat completion: {str(e)}"
logger.error(error_msg)
yield error_msg # Yield the error message to display in Gradio
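
# Minimal usage sketch outside Gradio (assumes a valid token is set): consume
# the generator and keep the last yielded value, which is the full reply.
#   reply = ""
#   for partial in respond("What is gravity?", [], "You are a helpful teacher.", 256, 0.7, 0.95):
#       reply = partial
#   print(reply)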

def main():
    """
    Sets up and launches the Gradio ChatInterface for the AI Teacher chatbot.
    """
    default_system_message = (
        "You are an AI Teacher, a knowledgeable and patient educator dedicated to helping students and learners. "
        "Your goal is to explain concepts clearly, provide step-by-step guidance, and encourage critical thinking. "
        "Adapt your explanations to the learner's level, ask follow-up questions to deepen understanding, and provide examples where helpful. "
        "Be supportive, professional, and engaging in all interactions."
    )
    demo = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value=default_system_message,
                label="AI Teacher Prompt",
                lines=3,
                placeholder="Customize the AI Teacher's teaching style or instructions",
            ),
            gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Maximum Response Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Response Creativity",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Response Diversity",
            ),
        ],
        title="AI Teacher: Your Study Companion",
        description=(
            "Welcome to AI Teacher, your personal guide for learning and studying! "
            "Ask questions about any subject, and I'll provide clear explanations, examples, and tips to help you succeed. "
            "Adjust the settings to customize how I respond to your questions."
        ),
        theme="soft",
        css="""
        .gradio-container { max-width: 900px; margin: auto; padding: 20px; }
        .chatbot { border-radius: 12px; background-color: #f9fafb; }
        h1 { color: #2b6cb0; }
        .message { font-size: 16px; }
        """,
    )
    try:
        logger.info("Launching Gradio interface for AI Teacher")
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        raise

if __name__ == "__main__":
    main()