Update app.py

app.py CHANGED
@@ -1,21 +1,52 @@
+# Import necessary libraries
 import os
+import time
+import copy
 import urllib
 import requests
+import random
+from threading import Thread
 from typing import List, Dict, Union
+from functools import lru_cache
+from bs4 import BeautifulSoup
 import torch
 import gradio as gr
-from …
+from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import InferenceClient
-from functools import lru_cache
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.DEBUG)
 
-#
+# Define device and load model and tokenizer
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
+# Phi-3 is a decoder-only (causal) model, so it is loaded with AutoModelForCausalLM.
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+# Set system prompt
+SYSTEM_PROMPT = [
+    {
+        "role": "system",
+        "content": [
+            {
+                "type": "text",
+                "text": """You are OpenGPT 4o, an exceptionally capable and versatile AI assistant. Designed to assist human users through insightful conversations, your key attributes include intelligence and knowledge, image generation and perception, and providing reliable information. Always ensure a seamless and enjoyable experience for the user.""",
+            },
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [
+            {
+                "type": "text",
+                "text": "Hello, I'm OpenGPT 4o. How can I help you today?",
+            },
+        ],
+    }
+]
 
-#
+# Function to check if a turn in the chat history only contains media
+def turn_is_pure_media(turn):
+    return turn[1] is None
+
+# Function to extract visible text from HTML content
 @lru_cache(maxsize=128)
 def extract_text_from_webpage(html_content):
     soup = BeautifulSoup(html_content, "html.parser")
@@ -24,120 +55,199 @@ def extract_text_from_webpage(html_content):
     visible_text = soup.get_text(strip=True)
     return visible_text
 
-#
+# Function to perform a Google search and return the results
 def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
     escaped_term = urllib.parse.quote_plus(term)
     start = 0
     all_results = []
-    max_chars_per_page = 8000
+    max_chars_per_page = 8000
 
     with requests.Session() as session:
         while start < num_results:
-…
-            all_results.append({"link": …
-…
-            break
-    logging.debug(f"Web search results: {all_results}")
+            resp = session.get(
+                url="https://www.google.com/search",
+                headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
+                params={
+                    "q": term,
+                    "num": num_results - start,
+                    "hl": lang,
+                    "start": start,
+                    "safe": safe,
+                },
+                timeout=timeout,
+                verify=ssl_verify,
+            )
+            resp.raise_for_status()
+            soup = BeautifulSoup(resp.text, "html.parser")
+            result_block = soup.find_all("div", attrs={"class": "g"})
+            if not result_block:
+                start += 1
+                continue
+            for result in result_block:
+                link = result.find("a", href=True)
+                if link:
+                    link = link["href"]
+                    try:
+                        webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"})
+                        webpage.raise_for_status()
+                        visible_text = extract_text_from_webpage(webpage.text)
+                        if len(visible_text) > max_chars_per_page:
+                            visible_text = visible_text[:max_chars_per_page] + "..."
+                        all_results.append({"link": link, "text": visible_text})
+                    except requests.exceptions.RequestException as e:
+                        print(f"Error fetching or processing {link}: {e}")
+                        all_results.append({"link": link, "text": None})
+                else:
+                    all_results.append({"link": None, "text": None})
+            start += len(result_block)
     return all_results
 
-#
+# Function to format the prompt for the language model
 def format_prompt(user_prompt, chat_history):
     prompt = "<s>"
     for item in chat_history:
         if isinstance(item, tuple):
-            prompt += f"[INST] {item[0]} [/INST]"
-            prompt += f" {item[1]}</s>"
+            prompt += f"[INST] {item[0]} [/INST] {item[1]}</s>"
         else:
             prompt += f" [Image] "
     prompt += f"[INST] {user_prompt} [/INST]"
     return prompt
 
-#
-def …
-…
+# Function for model inference
+def model_inference(
+    user_prompt,
+    chat_history,
+    web_search,
+    decoding_strategy,
+    temperature,
+    max_new_tokens,
+    repetition_penalty,
+    top_p,
+):
+    if not user_prompt["files"]:
+        if web_search:
+            web_results = search(user_prompt["text"])
+            web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
+            formatted_prompt = format_prompt(f"{user_prompt['text']} [WEB] {web2}", chat_history)
+            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                repetition_penalty=repetition_penalty,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p
+            )
+            # Decode only the newly generated tokens, not the echoed prompt.
+            response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+            return response
+        else:
+            formatted_prompt = format_prompt(user_prompt["text"], chat_history)
+            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                repetition_penalty=repetition_penalty,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p
+            )
+            # Decode only the newly generated tokens, not the echoed prompt.
+            response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+            return response
+    else:
+        return "Image input not supported in this implementation."
 
-…
+# Define Gradio interface components
+max_new_tokens = gr.Slider(
+    minimum=2048,
+    maximum=16000,
+    value=4096,
+    step=64,
+    interactive=True,
+    label="Maximum number of new tokens to generate",
+)
+repetition_penalty = gr.Slider(
+    minimum=0.01,
+    maximum=5.0,
+    value=1,
+    step=0.01,
+    interactive=True,
+    label="Repetition penalty",
+    info="1.0 is equivalent to no penalty",
+)
+decoding_strategy = gr.Radio(
+    [
+        "Greedy",
+        "Top P Sampling",
+    ],
+    value="Top P Sampling",
+    label="Decoding strategy",
+    interactive=True,
+    info="Greedy always picks the most likely next token; Top P samples from the smallest set of tokens whose cumulative probability reaches top_p.",
+)
+temperature = gr.Slider(
+    minimum=0.0,
+    maximum=2.0,
+    value=0.5,
+    step=0.05,
+    visible=True,
+    interactive=True,
+    label="Sampling temperature",
+    info="Higher values will produce more diverse outputs.",
+)
+top_p = gr.Slider(
+    minimum=0.01,
+    maximum=0.99,
+    value=0.9,
+    step=0.01,
+    visible=True,
+    interactive=True,
+    label="Top P",
+    info="Higher values are equivalent to sampling more low-probability tokens.",
+)
 
-…
-            web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
-            logging.debug(f"Formatted web search results: {web2}")
+# Create a chatbot interface
+chatbot = gr.Chatbot(
+    label="OpenGPT-4o-Chatty",
+    show_copy_button=True,
+    likeable=True,
+    layout="panel"
+)
 
-…
-    generate_kwargs = dict(max_new_tokens=5000, do_sample=True)
-    formatted_prompt = format_prompt(f"""You are OpenGPT 4o... [USER] {prompt} [OpenGPT 4o]""", [(prompt, )])
-    logging.debug(f"Formatted prompt without web search: {formatted_prompt}")
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-    for response in stream:
-        if not response.token.text == "</s>":
-            output += response.token.text
-        yield output
+# Define Gradio interface
+def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
+    # Wrap the plain-text input in the dict shape model_inference expects.
+    response = model_inference(
+        {"text": user_input, "files": []},
+        history,
+        web_search,
+        decoding_strategy,
+        temperature,
+        max_new_tokens,
+        repetition_penalty,
+        top_p,
+    )
+    history.append((user_input, response))
+    return history, history
 
-# Create …
-…
-    fn=…
+# Create Gradio interface
+interface = gr.Interface(
+    fn=chat_interface,
     inputs=[
-        gr.Textbox(label="User …
-        gr.…
+        gr.Textbox(label="User Input"),
+        gr.State([]),
+        gr.Checkbox(label="Web Search", value=True),
+        decoding_strategy,
+        temperature,
+        max_new_tokens,
+        repetition_penalty,
+        top_p
+    ],
+    outputs=[
+        chatbot,
+        gr.State([])
     ],
-…
+    title="OpenGPT-4o-Chatty",
+    description="An AI assistant capable of insightful conversations and web search."
 )
 
-…
+if __name__ == "__main__":
+    interface.launch()
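
Note: for this Space to build, the imports above imply a requirements.txt roughly like the following. The actual file is not part of this commit, so this list is an inference from the code, with versions omitted:

torch
transformers
gradio
huggingface_hub
requests
beautifulsoup4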
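
For reference, format_prompt wraps the conversation in Mistral-style [INST] tags. For a history of one text exchange, the output looks like this (illustrative values):

# format_prompt("And Germany?", [("What is the capital of France?", "Paris.")])
# returns:
# "<s>[INST] What is the capital of France? [/INST] Paris.</s>[INST] And Germany? [/INST]"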
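
A quick way to sanity-check the inference path locally is a minimal sketch like the one below. It is not part of the commit; it assumes the file is saved as app.py, that the Phi-3 weights download and fit on the selected device, and it skips the web-search branch. The argument names simply mirror the model_inference signature above.

# smoke_test.py - minimal local check (a sketch, not part of the commit)
from app import model_inference

demo_input = {"text": "What is the capital of France?", "files": []}
reply = model_inference(
    demo_input,
    chat_history=[],
    web_search=False,               # avoid the Google scrape for a quick check
    decoding_strategy="Top P Sampling",
    temperature=0.5,
    max_new_tokens=64,
    repetition_penalty=1.0,
    top_p=0.9,
)
print(reply)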
|