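"""AI Teacher: a Gradio chatbot backed by the Hugging Face Inference API.

Streams responses from a hosted chat model and keeps a small in-memory
store of past query-response pairs, which it reuses as lightweight
retrieval context for related follow-up questions.
"""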
import gradio as gr
from huggingface_hub import InferenceClient
from typing import Iterator, List, Tuple
import logging
from collections import deque
import re
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Initialize the InferenceClient with an API token from the environment
# (HUGGINGFACEHUB_API_TOKEN must be set; meta-llama models are gated on the Hub)
try:
    client = InferenceClient(
        model="meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    )
    logger.info("Successfully initialized InferenceClient")
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {str(e)}")
    raise

# Memory storage for learning from past queries
MEMORY = deque(maxlen=100)  # Store up to 100 query-response pairs
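# Note: this store lives in process memory only; it is shared across all users
# of the instance and is cleared whenever the app restarts.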

def add_to_memory(query: str, response: str):
    """Add a query-response pair to memory."""
    MEMORY.append({"query": query, "response": response})
    logger.info("Added query-response pair to memory")

def find_relevant_context(query: str, max_contexts: int = 2) -> str:
    """Retrieve relevant past queries and responses based on simple keyword matching."""
    query_words = set(re.findall(r'\w+', query.lower()))
    relevant = []
    
    for mem in MEMORY:
        mem_words = set(re.findall(r'\w+', mem["query"].lower()))
        overlap = len(query_words & mem_words) / max(len(query_words), 1)
        if overlap > 0.3:  # Threshold for relevance
            relevant.append(mem)
        if len(relevant) >= max_contexts:
            break
    
    if relevant:
        context = "\n".join(
            [f"Past Query: {mem['query']}\nPast Response: {mem['response']}" for mem in relevant]
        )
        return f"Relevant past interactions:\n{context}\n\n"
    return ""

def respond(
    message: str,
    history: List[Tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """
    Generates an educational response using past interactions for context.
    Args:
        message (str): The student's input question or query.
        history (List[Tuple[str, str]]): Chat history with student and AI teacher messages.
        system_message (str): The system prompt defining the AI teacher's behavior.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Controls randomness in response generation.
        top_p (float): Controls diversity via nucleus sampling.
    Yields:
        str: The cumulative response text, updated as each token arrives.
    """
    # Validate input parameters
    if not message.strip():
        raise ValueError("Input message cannot be empty")
    if max_tokens < 1 or max_tokens > 2048:
        raise ValueError("max_tokens must be between 1 and 2048")
    if temperature < 0.1 or temperature > 2.0:
        raise ValueError("temperature must be between 0.1 and 2.0")
    if top_p < 0.1 or top_p > 1.0:
        raise ValueError("top_p must be between 0.1 and 1.0")

    # Retrieve relevant past interactions
    context = find_relevant_context(message)
    
    # Construct the message history, injecting memory context only when any was found
    system_content = system_message
    if context:
        system_content += (
            "\n\nUse the following past interactions to inform your response if relevant:\n"
            + context
        )
    messages = [{"role": "system", "content": system_content}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        for chunk in stream:  # avoid shadowing the `message` parameter
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response
        # Store the original query and the final response in memory
        add_to_memory(message, response)
    except Exception as e:
        error_msg = f"Error during chat completion: {str(e)}"
        logger.error(error_msg)
        yield error_msg  # Yield the error message to display in Gradio

def main():
    """
    Sets up and launches the Gradio ChatInterface for the AI Teacher chatbot.
    """
    default_system_message = (
        "You are an AI Teacher, a knowledgeable and patient educator dedicated to helping students and learners. "
        "Your goal is to explain concepts clearly, provide step-by-step guidance, and encourage critical thinking. "
        "Adapt your explanations to the learner's level, ask follow-up questions to deepen understanding, and provide examples where helpful. "
        "Be supportive, professional, and engaging in all interactions."
    )

    demo = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value=default_system_message,
                label="AI Teacher Prompt",
                lines=3,
                placeholder="Customize the AI Teacher's teaching style or instructions",
            ),
            gr.Slider(
                minimum=1,
                maximum=2048,
                value=512,
                step=1,
                label="Maximum Response Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Response Creativity",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Response Diversity",
            ),
        ],
        title="AI Teacher: Your Study Companion",
        description=(
            "Welcome to AI Teacher, your personal guide for learning and studying! "
            "Ask questions about any subject, and I'll provide clear explanations, examples, and tips to help you succeed. "
            "Adjust the settings to customize how I respond to your questions."
        ),
        theme="soft",
        css="""
            .gradio-container { max-width: 900px; margin: auto; padding: 20px; }
            .chatbot { border-radius: 12px; background-color: #f9fafb; }
            h1 { color: #2b6cb0; }
            .message { font-size: 16px; }
        """,
    )

    try:
        logger.info("Launching Gradio interface for AI Teacher")
        demo.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        raise

if __name__ == "__main__":
    main()
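
# To run locally (a sketch; assumes this file is saved as app.py and the
# required packages are installed):
#   pip install gradio huggingface_hub
#   export HUGGINGFACEHUB_API_TOKEN=hf_...   # your Hugging Face access token
#   python app.py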