import urllib.parse
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Define device and load model and tokenizer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Load model and tokenizer
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error loading model and tokenizer: {e}")
    model = None
    tokenizer = None
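
# Note (optional sketch, not part of the original script): the model above loads in
# full float32 precision. On a memory-constrained GPU one option supported by the
# transformers API is half precision, e.g.:
#   model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(DEVICE)
# Whether this is appropriate depends on the deployment hardware; it is an assumption,
# not a required change.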

# Function to fetch news from Google News RSS feed
def fetch_news(term, num_results=2):
    logger.debug(f"Fetching news for term: {term}")
    encoded_term = urllib.parse.quote(term)
    url = f"https://news.google.com/rss/search?q={encoded_term}"
    feed = feedparser.parse(url)
    results = []
    for entry in feed.entries[:num_results]:
        results.append({"link": entry.link, "text": entry.title})
    logger.debug(f"Fetched news results: {results}")
    return results
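
# Hedged note: the Google News RSS endpoint also accepts locale parameters (e.g. hl,
# gl, ceid), so a query such as
#   https://news.google.com/rss/search?q={encoded_term}&hl=en-US&gl=US&ceid=US:en
# can pin results to a language/region. The plain query above works without them;
# whether to add them depends on the deployment locale.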

# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    logger.debug(f"Starting search for term: {term}")
    escaped_term = urllib.parse.quote_plus(term)
    start = 0
    all_results = []
    max_chars_per_page = 8000

    with requests.Session() as session:
        while start < num_results:
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    start += 1
                    continue
                for result in result_block:
                    link = result.find("a", href=True)
                    if link:
                        link = link["href"]
                        try:
                            webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"})
                            webpage.raise_for_status()
                            visible_text = extract_text_from_webpage(webpage.text)
                            if len(visible_text) > max_chars_per_page:
                                visible_text = visible_text[:max_chars_per_page] + "..."
                            all_results.append({"link": link, "text": visible_text})
                        except requests.exceptions.RequestException as e:
                            logger.error(f"Error fetching or processing {link}: {e}")
                            all_results.append({"link": link, "text": None})
                    else:
                        all_results.append({"link": None, "text": None})
                start += len(result_block)
            except Exception as e:
                logger.error(f"Error during search: {e}")
                break
    logger.debug(f"Search results: {all_results}")
    return all_results
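
# Note: search() scrapes Google's regular results page, which depends on the "div.g"
# markup staying stable and often returns empty or blocked responses for automated
# clients. It is kept here as an alternative retrieval path; the chat flow below only
# calls fetch_news().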

# Function to extract visible text from HTML content
def extract_text_from_webpage(html_content):
    soup = BeautifulSoup(html_content, "html.parser")
    for tag in soup(["script", "style", "header", "footer", "nav"]):
        tag.extract()
    # Use a separator so text from adjacent elements is not run together
    visible_text = soup.get_text(separator=" ", strip=True)
    return visible_text

# Function to format the prompt for the language model
def format_prompt(user_prompt, chat_history):
    logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
    prompt = ""
    for item in chat_history:
        prompt += f"User: {item[0]}\nAssistant: {item[1]}\n"
    prompt += f"User: {user_prompt}\nAssistant:"
    logger.debug(f"Formatted prompt: {prompt}")
    return prompt
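
# Optional sketch (an assumption, not part of the original flow): Phi-3 instruct
# checkpoints ship a chat template, so instead of the hand-rolled "User:/Assistant:"
# format above, the prompt can be built with tokenizer.apply_chat_template. The helper
# below is illustrative and is not called anywhere in this script.
def format_prompt_with_chat_template(user_prompt, chat_history, tokenizer):
    messages = []
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_prompt})
    # add_generation_prompt=True appends the assistant turn marker so the model
    # continues as the assistant rather than starting another user turn.
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)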

# Function for model inference
def model_inference(
        user_prompt,
        chat_history,
        web_search,
        temperature,
        max_new_tokens,
        repetition_penalty,
        top_p,
        tokenizer  # Pass tokenizer as an argument
):
    logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}")
    if not isinstance(user_prompt, dict):
        logger.error("Invalid input format. Expected a dictionary.")
        return "Invalid input format. Expected a dictionary."

    if "files" not in user_prompt:
        user_prompt["files"] = []

    if not user_prompt["files"]:
        if web_search:
            logger.debug("Performing news search")
            news_results = fetch_news(user_prompt["text"])
            news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
            formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news2}", chat_history)
            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
            if model:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p
                )
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            else:
                response = "Model is not available. Please try again later."
            logger.debug(f"Model response: {response}")
            return response
        else:
            formatted_prompt = format_prompt(user_prompt["text"], chat_history)
            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
            if model:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p
                )
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            else:
                response = "Model is not available. Please try again later."
            logger.debug(f"Model response: {response}")
            return response
    else:
        return "Image input not supported in this implementation."

# Define Gradio interface components
max_new_tokens = gr.Slider(
    minimum=1,
    maximum=16000,
    value=2048,
    step=64,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    info="Higher values are equivalent to sampling more low-probability tokens.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.5,
    step=0.05,
    visible=True,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.9,
    step=0.01,
    visible=True,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)

# Create a chatbot interface
chatbot = gr.Chatbot(
    label="OpenGPT-4o-Chatty",
    show_copy_button=True,
    likeable=True,
    layout="panel"
)

# Chat handler wired into the Gradio interface.
# Note: decoding_strategy is accepted to match the UI inputs, but generation in model_inference always samples (do_sample=True).
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # Ensure the tokenizer is accessible within the function scope
    global tokenizer

    # Wrap the user input in a dictionary as expected by the model_inference function
    user_prompt = {"text": user_input, "files": []}
    
    # Perform model inference
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer  # Pass tokenizer to the model_inference function
    )
    
    # Update history with the user input and model response
    history.append((user_input, response))
    
    # Return the response and updated history
    return response, history

# Assemble the Gradio interface, reusing the components defined above
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # Initialize the chat history as an empty list
        gr.Checkbox(label="Perform Web Search"),
        gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
        gr.Slider(minimum=1, maximum=16000, step=64, label="Maximum number of new tokens to generate", value=2048),
        gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9)
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([])  # Update the chat history
    ],
    live=True
)

# Launch the Gradio interface
interface.launch()
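
# Hedged note: interface.launch() accepts deployment options such as share=True
# (temporary public URL) or server_name="0.0.0.0" / server_port=7860 when running
# inside a container. Which, if any, of these applies depends on where the app is hosted.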