Shreyas94 committed
Commit ddfb119 · verified · 1 Parent(s): bf295f9

Update app.py

Files changed (1): app.py +209 -99
app.py CHANGED
@@ -1,21 +1,52 @@
  import os
  import urllib
  import requests
  from typing import List, Dict, Union
  import torch
  import gradio as gr
- from bs4 import BeautifulSoup
  from huggingface_hub import InferenceClient
- from functools import lru_cache
- import logging
-
- # Set up logging
- logging.basicConfig(level=logging.DEBUG)

- # Set device to CUDA if available, otherwise CPU
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- # Extract text from webpage
  @lru_cache(maxsize=128)
  def extract_text_from_webpage(html_content):
      soup = BeautifulSoup(html_content, "html.parser")
@@ -24,120 +55,199 @@ def extract_text_from_webpage(html_content):
      visible_text = soup.get_text(strip=True)
      return visible_text

- # Perform a Google search and return the results
  def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
      escaped_term = urllib.parse.quote_plus(term)
      start = 0
      all_results = []
-     max_chars_per_page = 8000  # Limit the number of characters from each webpage

      with requests.Session() as session:
          while start < num_results:
-             try:
-                 resp = session.get(
-                     url="https://www.google.com/search",
-                     headers={"User-Agent": "Mozilla/5.0"},
-                     params={"q": term, "num": num_results - start, "hl": lang, "start": start, "safe": safe},
-                     timeout=timeout,
-                     verify=ssl_verify,
-                 )
-                 resp.raise_for_status()
-                 logging.debug(f"Raw HTML response from Google: {resp.text[:1000]}")  # Log the first 1000 characters of the HTML
-
-                 soup = BeautifulSoup(resp.text, "html.parser")
-                 result_block = soup.find_all("div", attrs={"class": "g"})
-                 if not result_block:
-                     start += 1
-                     continue
-
-                 for result in result_block:
-                     link_tag = result.find("a", href=True)
-                     if link_tag:
-                         link = link_tag["href"]
-                         try:
-                             webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0"})
-                             webpage.raise_for_status()
-                             visible_text = extract_text_from_webpage(webpage.text)
-                             if len(visible_text) > max_chars_per_page:
-                                 visible_text = visible_text[:max_chars_per_page] + "..."
-                             all_results.append({"link": link, "text": visible_text})
-                         except requests.exceptions.RequestException as e:
-                             logging.error(f"Error fetching or processing {link}: {e}")
-                             all_results.append({"link": link, "text": None})
-                     else:
-                         all_results.append({"link": None, "text": None})
-                 start += len(result_block)
-             except requests.exceptions.RequestException as e:
-                 logging.error(f"Error during search request: {e}")
-                 break
-     logging.debug(f"Web search results: {all_results}")
      return all_results

- # Format the prompt for the language model
  def format_prompt(user_prompt, chat_history):
      prompt = "<s>"
      for item in chat_history:
          if isinstance(item, tuple):
-             prompt += f"[INST] {item[0]} [/INST]"
-             prompt += f" {item[1]}</s>"
          else:
              prompt += f" [Image] "
      prompt += f"[INST] {user_prompt} [/INST]"
      return prompt

- # Model inference function
- def start_inference(prompt, enable_web_search):
-     return next(model_inference(prompt, enable_web_search))

- def model_inference(prompt, enable_web_search):
-     for response in fetch_response(prompt, enable_web_search):
-         yield response

- def fetch_response(prompt, enable_web_search):
-     if enable_web_search:
-         # Perform web search and generate text based on the retrieved results
-         web_results = search(prompt)
-         if not web_results:
-             web2 = "No results found."
-         else:
-             web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
-         logging.debug(f"Formatted web search results: {web2}")

-         client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
-         generate_kwargs = dict(max_new_tokens=4000, do_sample=True)
-         formatted_prompt = format_prompt(
-             f"""You are OpenGPT 4o... [USER] {prompt} [WEB] {web2} [OpenGPT 4o]""",
-             [(prompt, web2)])
-         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-         output = ""
-         for response in stream:
-             if not response.token.text == "</s>":
-                 output += response.token.text
-             yield output
-     else:
-         # Use the microsoft/Phi-3-mini-4k-instruct model for generating text based on user prompts
-         client = InferenceClient("microsoft/Phi-3-mini-4k-instruct")
-         generate_kwargs = dict(max_new_tokens=5000, do_sample=True)
-         formatted_prompt = format_prompt(f"""You are OpenGPT 4o... [USER] {prompt} [OpenGPT 4o]""", [(prompt, )])
-         logging.debug(f"Formatted prompt without web search: {formatted_prompt}")
-
-         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-         output = ""
-         for response in stream:
-             if not response.token.text == "</s>":
-                 output += response.token.text
-             yield output

- # Create a chatbot interface with a Fetch button
- chatbot = gr.Interface(
-     fn=start_inference,
      inputs=[
-         gr.Textbox(label="User Prompt", placeholder="Enter your prompt here..."),
-         gr.Checkbox(label="Enable Web Search", value=False)
      ],
-     outputs=gr.Textbox(label="Response", placeholder="Responses will appear here..."),
-     live=True
  )

- # Launch the Gradio interface
- chatbot.launch()
 
+ # Import necessary libraries
  import os
+ import time
+ import copy
  import urllib
  import requests
+ import random
+ from threading import Thread
  from typing import List, Dict, Union
+ from functools import lru_cache
+ from bs4 import BeautifulSoup
  import torch
  import gradio as gr
+ from transformers import TextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer
  from huggingface_hub import InferenceClient

+ # Define device and load model and tokenizer
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
+ # Phi-3 is a decoder-only model, so it must be loaded as a causal LM
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+ # Set system prompt
+ SYSTEM_PROMPT = [
+     {
+         "role": "system",
+         "content": [
+             {
+                 "type": "text",
+                 "text": """You are OpenGPT 4o, an exceptionally capable and versatile AI assistant. Designed to assist human users through insightful conversations, your key attributes include intelligence and knowledge, image generation and perception, and providing reliable information. Always ensure a seamless and enjoyable experience for the user.""",
+             },
+         ],
+     },
+     {
+         "role": "assistant",
+         "content": [
+             {
+                 "type": "text",
+                 "text": "Hello, I'm OpenGPT 4o. How can I help you today?",
+             },
+         ],
+     }
+ ]

+ # Function to check if a turn in the chat history only contains media
+ def turn_is_pure_media(turn):
+     return turn[1] is None
+
+ # Function to extract visible text from HTML content
  @lru_cache(maxsize=128)
  def extract_text_from_webpage(html_content):
      soup = BeautifulSoup(html_content, "html.parser")
      visible_text = soup.get_text(strip=True)
      return visible_text

+ # Function to perform a Google search and return the results
  def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
      escaped_term = urllib.parse.quote_plus(term)
      start = 0
      all_results = []
+     max_chars_per_page = 8000

      with requests.Session() as session:
          while start < num_results:
+             resp = session.get(
+                 url="https://www.google.com/search",
+                 headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
+                 params={
+                     "q": term,
+                     "num": num_results - start,
+                     "hl": lang,
+                     "start": start,
+                     "safe": safe,
+                 },
+                 timeout=timeout,
+                 verify=ssl_verify,
+             )
+             resp.raise_for_status()
+             soup = BeautifulSoup(resp.text, "html.parser")
+             result_block = soup.find_all("div", attrs={"class": "g"})
+             if not result_block:
+                 start += 1
+                 continue
+             for result in result_block:
+                 link = result.find("a", href=True)
+                 if link:
+                     link = link["href"]
+                     try:
+                         webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"})
+                         webpage.raise_for_status()
+                         visible_text = extract_text_from_webpage(webpage.text)
+                         if len(visible_text) > max_chars_per_page:
+                             visible_text = visible_text[:max_chars_per_page] + "..."
+                         all_results.append({"link": link, "text": visible_text})
+                     except requests.exceptions.RequestException as e:
+                         print(f"Error fetching or processing {link}: {e}")
+                         all_results.append({"link": link, "text": None})
+                 else:
+                     all_results.append({"link": None, "text": None})
+             start += len(result_block)
      return all_results

+ # Function to format the prompt for the language model
  def format_prompt(user_prompt, chat_history):
      prompt = "<s>"
      for item in chat_history:
          if isinstance(item, tuple):
+             prompt += f"[INST] {item[0]} [/INST] {item[1]}</s>"
          else:
              prompt += f" [Image] "
      prompt += f"[INST] {user_prompt} [/INST]"
      return prompt

+ # Function for model inference
+ def model_inference(
+     user_prompt,
+     chat_history,
+     web_search,
+     decoding_strategy,
+     temperature,
+     max_new_tokens,
+     repetition_penalty,
+     top_p,
+ ):
+     if not user_prompt["files"]:
+         if web_search:
+             web_results = search(user_prompt["text"])
+             web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
+             formatted_prompt = format_prompt(f"{user_prompt['text']} [WEB] {web2}", chat_history)
+             inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             return response
+         else:
+             formatted_prompt = format_prompt(user_prompt["text"], chat_history)
+             inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 repetition_penalty=repetition_penalty,
+                 do_sample=True,
+                 temperature=temperature,
+                 top_p=top_p
+             )
+             response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             return response
+     else:
+         return "Image input not supported in this implementation."

+ # Define Gradio interface components
+ max_new_tokens = gr.Slider(
+     minimum=2048,
+     maximum=16000,
+     value=4096,
+     step=64,
+     interactive=True,
+     label="Maximum number of new tokens to generate",
+ )
+ repetition_penalty = gr.Slider(
+     minimum=0.01,
+     maximum=5.0,
+     value=1,
+     step=0.01,
+     interactive=True,
+     label="Repetition penalty",
+     info="1.0 is equivalent to no penalty",
+ )
+ decoding_strategy = gr.Radio(
+     [
+         "Greedy",
+         "Top P Sampling",
+     ],
+     value="Top P Sampling",
+     label="Decoding strategy",
+     interactive=True,
+     info="Greedy picks the most likely token at each step; Top P samples from the nucleus.",
+ )
+ temperature = gr.Slider(
+     minimum=0.0,
+     maximum=2.0,
+     value=0.5,
+     step=0.05,
+     visible=True,
+     interactive=True,
+     label="Sampling temperature",
+     info="Higher values will produce more diverse outputs.",
+ )
+ top_p = gr.Slider(
+     minimum=0.01,
+     maximum=0.99,
+     value=0.9,
+     step=0.01,
+     visible=True,
+     interactive=True,
+     label="Top P",
+     info="Higher values are equivalent to sampling more low-probability tokens.",
+ )

+ # Create a chatbot interface
+ chatbot = gr.Chatbot(
+     label="OpenGPT-4o-Chatty",
+     show_copy_button=True,
+     likeable=True,
+     layout="panel"
+ )

+ # Define Gradio interface
+ def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
+     # model_inference expects a dict with "text" and "files" keys; the
+     # Textbox delivers a plain string, so wrap it here
+     response = model_inference(
+         {"text": user_input, "files": []},
+         history,
+         web_search,
+         decoding_strategy,
+         temperature,
+         max_new_tokens,
+         repetition_penalty,
+         top_p,
+     )
+     history.append((user_input, response))
+     return history, history

+ # Create Gradio interface
+ interface = gr.Interface(
+     fn=chat_interface,
      inputs=[
+         gr.Textbox(label="User Input"),
+         gr.State([]),
+         gr.Checkbox(label="Web Search", value=True),
+         decoding_strategy,
+         temperature,
+         max_new_tokens,
+         repetition_penalty,
+         top_p
+     ],
+     outputs=[
+         chatbot,
+         gr.State([])
      ],
+     title="OpenGPT-4o-Chatty",
+     description="An AI assistant capable of insightful conversations and web search."
  )

+ if __name__ == "__main__":
+     interface.launch()
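
Two loose ends in the new code are worth noting: TextIteratorStreamer and Thread are imported but never used, and the decoding_strategy choice is collected by the interface while model_inference always samples. Below is a minimal sketch of how both could be wired in, assuming the model, tokenizer, and DEVICE globals defined above; stream_generate is a hypothetical helper, not part of this commit.

# Sketch only (not part of this commit): hook up the unused
# TextIteratorStreamer/Thread imports and the decoding_strategy choice.
def stream_generate(formatted_prompt, decoding_strategy, temperature, top_p, max_new_tokens):
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
    # The streamer yields decoded text chunks as generate() produces tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    if decoding_strategy == "Top P Sampling":
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:  # "Greedy"
        generate_kwargs.update(do_sample=False)
    # Run generation in a background thread so tokens can be consumed live
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    output = ""
    for chunk in streamer:
        output += chunk
        yield output

Yielding the accumulated output keeps the behavior consistent with the incremental updates the old fetch_response generator produced.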