import os
import urllib.parse
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Define device and model name
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Load model and tokenizer
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error loading model and tokenizer: {e}")
    model = None
    tokenizer = None
# Function to fetch news from the Google News RSS feed
def fetch_news(term, num_results=2):
    logger.debug(f"Fetching news for term: {term}")
    encoded_term = urllib.parse.quote(term)
    url = f"https://news.google.com/rss/search?q={encoded_term}"
    feed = feedparser.parse(url)
    results = []
    for entry in feed.entries[:num_results]:
        results.append({"link": entry.link, "text": entry.title})
    logger.debug(f"Fetched news results: {results}")
    return results
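# Illustrative example (hypothetical output shape, not real data):
# fetch_news("open source AI") might return
# [{"link": "https://news.google.com/...", "text": "Some headline"},
#  {"link": "https://news.google.com/...", "text": "Another headline"}]
# The rest of the script only relies on the "link" and "text" keys.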
# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    logger.debug(f"Starting search for term: {term}")
    start = 0
    all_results = []
    max_chars_per_page = 8000  # Truncate each page's visible text to keep the prompt manageable
    user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
    with requests.Session() as session:
        while start < num_results:
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers={"User-Agent": user_agent},
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    start += 1
                    continue
                for result in result_block:
                    link = result.find("a", href=True)
                    if link:
                        link = link["href"]
                        try:
                            webpage = session.get(link, headers={"User-Agent": user_agent}, timeout=timeout)
                            webpage.raise_for_status()
                            visible_text = extract_text_from_webpage(webpage.text)
                            if len(visible_text) > max_chars_per_page:
                                visible_text = visible_text[:max_chars_per_page] + "..."
                            all_results.append({"link": link, "text": visible_text})
                        except requests.exceptions.RequestException as e:
                            logger.error(f"Error fetching or processing {link}: {e}")
                            all_results.append({"link": link, "text": None})
                    else:
                        all_results.append({"link": None, "text": None})
                start += len(result_block)
            except Exception as e:
                logger.error(f"Error during search: {e}")
                break
    logger.debug(f"Search results: {all_results}")
    return all_results
# Function to extract visible text from HTML content
def extract_text_from_webpage(html_content):
    soup = BeautifulSoup(html_content, "html.parser")
    # Drop non-content tags before extracting text
    for tag in soup(["script", "style", "header", "footer", "nav"]):
        tag.extract()
    # Use a space separator so text from adjacent elements does not run together
    visible_text = soup.get_text(separator=" ", strip=True)
    return visible_text
# Function to format the prompt for the language model
def format_prompt(user_prompt, chat_history):
    logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
    prompt = ""
    for user_msg, assistant_msg in chat_history:
        prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    prompt += f"User: {user_prompt}\nAssistant:"
    logger.debug(f"Formatted prompt: {prompt}")
    return prompt
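# Note (assumption, not part of the original flow): Phi-3-mini-4k-instruct ships with a
# chat template, so tokenizer.apply_chat_template(messages, tokenize=False,
# add_generation_prompt=True) could be used instead of this plain "User:/Assistant:"
# format; the simple format is kept here to match the rest of the script.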
# Function for model inference
def model_inference(
    user_prompt,
    chat_history,
    web_search,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
    tokenizer,  # Pass tokenizer as an argument
):
    logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}")
    if not isinstance(user_prompt, dict):
        logger.error("Invalid input format. Expected a dictionary.")
        return "Invalid input format. Expected a dictionary."
    if "files" not in user_prompt:
        user_prompt["files"] = []
    if user_prompt["files"]:
        return "Image input not supported in this implementation."
    # Build the prompt, optionally augmented with news results for the query
    if web_search:
        logger.debug("Performing news search")
        news_results = fetch_news(user_prompt["text"])
        news_context = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
        formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news_context}", chat_history)
    else:
        formatted_prompt = format_prompt(user_prompt["text"], chat_history)
    # Bail out early if the model or tokenizer failed to load
    if model is None or tokenizer is None:
        return "Model is not available. Please try again later."
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.debug(f"Model response: {response}")
    return response
# Define Gradio interface components (reused by the Interface below)
max_new_tokens = gr.Slider(
    minimum=1,
    maximum=16000,
    value=2048,
    step=64,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    info="Choose between greedy decoding and nucleus (Top-P) sampling.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.5,
    step=0.05,
    visible=True,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.9,
    step=0.01,
    visible=True,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)
# Create a chatbot display component (defined here but not wired into the simple gr.Interface below)
chatbot = gr.Chatbot(
    label="OpenGPT-4o-Chatty",
    show_copy_button=True,
    likeable=True,
    layout="panel",
)
# Define the Gradio chat handler
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # Ensure the tokenizer is accessible within the function scope
    global tokenizer
    # Note: decoding_strategy is accepted here but not currently used;
    # model_inference always samples (do_sample=True).
    # Wrap the user input in a dictionary as expected by model_inference
    user_prompt = {"text": user_input, "files": []}
    # Perform model inference
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # Pass tokenizer to model_inference
    )
    # Update history with the user input and model response
    history.append((user_input, response))
    # Return the response and updated history
    return response, history
# Define the Gradio interface, reusing the components defined above
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # Initialize the chat history as an empty list
        gr.Checkbox(label="Perform Web Search"),
        decoding_strategy,
        temperature,
        max_new_tokens,
        repetition_penalty,
        top_p,
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([]),  # Updated chat history
    ],
    live=True,
)
# Launch the Gradio interface
interface.launch()
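# Assumed runtime dependencies for this Space (e.g. listed in requirements.txt):
# torch, transformers, gradio, requests, beautifulsoup4, feedparser.
# A GPU is optional: the model falls back to CPU when CUDA is unavailable,
# though CPU inference with Phi-3-mini will be noticeably slower.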