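"""Gradio app for chatting about a NeurIPS paper: it scrapes the title, authors,
and PDF text from OpenReview, then answers questions through a SambaNova-hosted
Llama model behind an OpenAI-compatible API."""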
import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import requests
from io import BytesIO
from transformers import AutoTokenizer
import os
from openai import OpenAI

# Cache for tokenizers to avoid reloading
tokenizer_cache = {}
# Fetch paper information (title, authors, abstract) from OpenReview
def fetch_paper_info_neurips(paper_id):
    url = f"https://openreview.net/forum?id={paper_id}"
    response = requests.get(url)
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.content, 'html.parser')

    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'

    # Extract authors
    authors = []
    author_div = soup.find('div', class_='forum-authors')
    if author_div:
        author_tags = author_div.find_all('a')
        authors = [tag.get_text(strip=True) for tag in author_tags]
    author_list = ', '.join(authors) if authors else 'Authors not found'

    # Extract abstract ('string=' replaces the deprecated 'text=' argument)
    abstract_tag = soup.find('strong', string='Abstract:')
    if abstract_tag:
        abstract_paragraph = abstract_tag.find_next_sibling('div')
        abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'
    else:
        abstract = 'Abstract not found'

    # Construct the preamble in Markdown (abstract intentionally omitted)
    # preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n**Abstract:**\n{abstract}"
    preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
    return preamble
# Download a paper's PDF from OpenReview and extract its full text
def fetch_paper_content(paper_id):
    try:
        # Construct the URL
        url = f"https://openreview.net/pdf?id={paper_id}"

        # Fetch the PDF
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for HTTP errors

        # Read the PDF content
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)

        # Extract text from the PDF; extract_text() can return None for some pages
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""

        return text  # Return full text; truncation is handled later
    except Exception as e:
        print(f"An error occurred while fetching the PDF: {e}")
        return None
# Build the chat tab; paper_id is a Gradio component holding the paper ID
def paper_chat_tab(paper_id):
    with gr.Blocks() as demo:
        with gr.Column():
            # Markdown area to display the paper title and authors
            content = gr.Markdown(value="")

            # Hint to the user about rate limits
            gr.Markdown("**Note:** Providing your own SambaNova token can help you avoid rate limits.")

            # Input for the SambaNova API token
            hf_token_input = gr.Textbox(
                label="Enter your SambaNova token (optional)",
                type="password",
                placeholder="Enter your SambaNova token to avoid rate limits"
            )

            models = [
                # "Meta-Llama-3.1-8B-Instruct",
                "Meta-Llama-3.1-70B-Instruct",
                # "Meta-Llama-3.1-405B-Instruct",
            ]
            default_model = models[0]

            # Dropdown for selecting the model
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value=default_model
            )

            # State to store the paper content
            paper_content = gr.State()

            # Create a column per model; only the default model's column is visible
            columns = []
            for model_name in models:
                column = gr.Column(visible=(model_name == default_model))
                with column:
                    create_chat_interface(model_name, paper_content, hf_token_input)
                columns.append(column)

            gr.HTML(
                '<img src="https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg" width="100px" />')
            gr.Markdown("**Note:** This model is served by SambaNova.")
            # Toggle column visibility so only the selected model's chat is shown
            def update_columns(selected_model):
                visibility = []
                for model_name in models:
                    visibility.append(gr.update(visible=(model_name == selected_model)))
                return visibility

            model_dropdown.change(
                fn=update_columns,
                inputs=model_dropdown,
                outputs=columns,
                api_name=False,
                queue=False,
            )

            # Refresh the preamble Markdown and the stored paper text when the
            # paper ID or model changes; text is None if fetching the PDF failed
            def update_paper_info(paper_id, selected_model):
                preamble = fetch_paper_info_neurips(paper_id)
                text = fetch_paper_content(paper_id)
                return preamble, text

            paper_id.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content]
            )
            model_dropdown.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content],
                queue=False,
            )
    return demo
# Build a ChatInterface for one model, budgeting tokens for the paper context
def create_chat_interface(model_name, paper_content, hf_token_input):
    # Load the tokenizer once and cache it; a small Llama tokenizer stands in
    # for the served model's own, since it is only used for token counting
    if model_name not in tokenizer_cache:
        # tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
                                                  token=os.getenv("HF_TOKEN"))
        tokenizer_cache[model_name] = tokenizer
    else:
        tokenizer = tokenizer_cache[model_name]

    max_total_tokens = 50000  # Maximum tokens allowed across context and messages
    # Handle one chat turn: build the message list, budget tokens, stream the reply
    def get_fn(message, history, paper_content_value, hf_token_value):
        # Include the paper content as context
        if paper_content_value:
            context = f"The following is the content of the paper:\n{paper_content_value}\n\n"
        else:
            context = ""

        # Tokenize the context
        context_tokens = tokenizer.encode(context)
        context_token_length = len(context_tokens)

        # Prepare the messages without the context
        messages = []
        message_tokens_list = []
        total_tokens = context_token_length  # Start with the context tokens

        for user_msg, assistant_msg in history:
            # Tokenize the user message
            user_tokens = tokenizer.encode(user_msg)
            messages.append({"role": "user", "content": user_msg})
            message_tokens_list.append(len(user_tokens))
            total_tokens += len(user_tokens)

            # Tokenize the assistant message
            if assistant_msg:
                assistant_tokens = tokenizer.encode(assistant_msg)
                messages.append({"role": "assistant", "content": assistant_msg})
                message_tokens_list.append(len(assistant_tokens))
                total_tokens += len(assistant_tokens)

        # Tokenize the new user message
        message_tokens = tokenizer.encode(message)
        messages.append({"role": "user", "content": message})
        message_tokens_list.append(len(message_tokens))
        total_tokens += len(message_tokens)

        # If the total exceeds the budget, truncate the context first
        if total_tokens > max_total_tokens:
            available_tokens = max_total_tokens - (total_tokens - context_token_length)
            if available_tokens > 0:
                # Truncate the context to fit the available tokens
                truncated_context_tokens = context_tokens[:available_tokens]
                context = tokenizer.decode(truncated_context_tokens)
                total_tokens = total_tokens - context_token_length + available_tokens
                context_token_length = available_tokens
            else:
                # Not enough space for any context; drop it entirely
                context = ""
                total_tokens -= context_token_length
                context_token_length = 0

        # If the total still exceeds the limit, drop the oldest messages
        while total_tokens > max_total_tokens and len(messages) > 1:
            messages.pop(0)
            total_tokens -= message_tokens_list.pop(0)

        # Rebuild the final message list, including the (possibly truncated) context
        final_messages = []
        if context:
            final_messages.append({"role": "system", "content": context})
        final_messages.extend(messages)
        # Prefer the user-supplied SambaNova token; fall back to the env var
        api_key = hf_token_value or os.getenv("SAMBANOVA_API_KEY")
        if not api_key:
            raise ValueError("API token is not provided.")

        # Initialize an OpenAI-compatible client pointed at SambaNova
        client = OpenAI(
            base_url="https://api.sambanova.ai/v1/",
            api_key=api_key,
        )

        try:
            # Create the chat completion and stream partial responses back
            completion = client.chat.completions.create(
                model=model_name,
                messages=final_messages,
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content or ""
                response_text += delta
                yield response_text
        except Exception as e:
            yield f"Error: {str(e)}"
    # Create the ChatInterface, passing the paper text and token as extra inputs
    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=gr.Chatbot(
            label="Chatbot",
            scale=1,
            height=400,
            autoscroll=True
        ),
        additional_inputs=[paper_content, hf_token_input],
        # examples=["What are the main findings of this paper?", "Explain the methodology used in this research."]
    )
    return chat_interface
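
# A minimal smoke-test sketch, not part of the deployed Space: it exercises the
# two fetch helpers directly, before any UI wiring. The paper ID below is a
# hypothetical placeholder, not a real OpenReview ID.
if __name__ == "__main__":
    example_id = "PASTE_AN_OPENREVIEW_ID"  # hypothetical placeholder ID
    preamble = fetch_paper_info_neurips(example_id)
    text = fetch_paper_content(example_id)
    print(preamble or "No paper info found")
    print(f"Extracted {len(text) if text else 0} characters of PDF text")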