import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import requests
from io import BytesIO
from transformers import AutoTokenizer
import os
from openai import OpenAI

# Cache for tokenizers to avoid reloading
tokenizer_cache = {}

# Seconds to wait for OpenReview before giving up. Without a timeout a stalled
# connection would block the Gradio event handler indefinitely.
REQUEST_TIMEOUT = 30


def _extract_abstract(soup):
    """Return the abstract text from a parsed OpenReview forum page.

    Looks for a ``<strong>Abstract:</strong>`` label and takes the text of the
    ``<div>`` sibling that follows it. Returns ``'Abstract not found'`` when
    either element is missing.
    """
    abstract_label = soup.find('strong', string='Abstract:')
    abstract_paragraph = abstract_label.find_next_sibling('div') if abstract_label else None
    if abstract_paragraph:
        return abstract_paragraph.get_text(strip=True)
    return 'Abstract not found'


# Function to fetch paper information from OpenReview
def fetch_paper_info_neurips(paper_id, include_abstract=False):
    """Fetch a paper's title and authors from its OpenReview forum page.

    Args:
        paper_id: OpenReview forum id (the ``id=`` query parameter).
        include_abstract: When True, also append the scraped abstract to the
            preamble. Defaults to False (title and authors only).

    Returns:
        A Markdown preamble whose title links to the forum page, followed by
        the comma-separated author list. Returns ``None`` if the page could not
        be fetched (network error or non-200 status).
    """
    url = f"https://openreview.net/forum?id={paper_id}"
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        print(f"An error occurred: {e}")
        return None
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.content, 'html.parser')

    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'

    # Extract authors
    author_div = soup.find('div', class_='forum-authors')
    authors = [tag.get_text(strip=True) for tag in author_div.find_all('a')] if author_div else []
    author_list = ', '.join(authors) if authors else 'Authors not found'

    # Construct preamble in Markdown
    preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
    if include_abstract:
        preamble += f"**Abstract:**\n{_extract_abstract(soup)}"
    return preamble


def fetch_paper_content(paper_id):
    """Download a paper's PDF from OpenReview and return its extracted text.

    The full text is returned untruncated; fitting it into the model's
    context window is handled later by the chat function.

    Returns:
        The text of all pages joined by newlines. Returns ``None`` if
        downloading or parsing failed; the error is printed.
    """
    url = f"https://openreview.net/pdf?id={paper_id}"
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # Raise an exception for HTTP errors
        reader = PdfReader(BytesIO(response.content))
        # extract_text() may yield None for pages without a text layer. Treat
        # those as empty so one image-only page does not discard the whole
        # paper. The newline keeps words at page boundaries from being glued.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


def paper_chat_tab(paper_id):
    """Build the "chat with this paper" tab.

    Args:
        paper_id: A Gradio component whose value is an OpenReview paper id.
            Its ``change`` event reloads the paper header and text in this tab.

    Returns:
        The ``gr.Blocks`` containing:
            - the paper header,
            - the optional token input,
            - the model selector,
            - one chat interface per model.
    """
    with gr.Blocks() as demo:
        with gr.Column():
            # Markdown showing the paper title and authors
            content = gr.Markdown(value="")

            # Preamble message to hint the user
            gr.Markdown("**Note:** Providing your own sambanova token can help you avoid rate limits.")

            # Optional user-supplied SambaNova API key
            hf_token_input = gr.Textbox(
                label="Enter your sambanova token (optional)",
                type="password",
                placeholder="Enter your sambanova token to avoid rate limits",
            )

            models = [
                # "Meta-Llama-3.1-8B-Instruct",
                "Meta-Llama-3.1-70B-Instruct",
                # "Meta-Llama-3.1-405B-Instruct",
            ]
            default_model = models[0]

            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value=default_model,
            )

            # Extracted paper text, shared by every chat interface below
            paper_content = gr.State()

            # One column per model; only the selected model's column is visible
            columns = []
            for model_name in models:
                column = gr.Column(visible=(model_name == default_model))
                with column:
                    create_chat_interface(model_name, paper_content, hf_token_input)
                columns.append(column)

            gr.HTML('')
            gr.Markdown("**Note:** This model is supported by SambaNova.")

            def update_columns(selected_model):
                """Return visibility updates that show only ``selected_model``'s column."""
                return [gr.update(visible=(name == selected_model)) for name in models]

            model_dropdown.change(
                fn=update_columns,
                inputs=model_dropdown,
                outputs=columns,
                api_name=False,
                queue=False,
            )

            def update_paper_info(paper_id, selected_model):
                """Return ``(header_markdown, paper_text)`` for ``paper_id``.

                ``selected_model`` is received from the event inputs but not
                used. ``paper_text`` is None when the PDF could not be loaded.
                """
                if not paper_id:
                    # Nothing to fetch for an empty id; clear the tab instead
                    return "", None
                preamble = fetch_paper_info_neurips(paper_id)
                text = fetch_paper_content(paper_id)
                return preamble or "", text

            # Update paper content when paper ID or model changes
            paper_id.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content],
            )
            model_dropdown.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content],
                queue=False,
            )
    return demo


def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI-style ``{"role", "content"}`` dicts.

    Accepts both history formats:
        - tuples format: ``[user, assistant]`` pairs;
        - "messages" format: ``{"role": ..., "content": ...}`` dicts.

    Turns that are empty, non-text, or have a role other than user/assistant
    are skipped.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            pairs = [(turn.get("role"), turn.get("content"))]
        else:
            user_msg, assistant_msg = turn
            pairs = [("user", user_msg), ("assistant", assistant_msg)]
        for role, text in pairs:
            if role in ("user", "assistant") and isinstance(text, str) and text:
                messages.append({"role": role, "content": text})
    return messages


def create_chat_interface(model_name, paper_content, hf_token_input, max_total_tokens=50000):
    """Create a streaming SambaNova chat interface grounded in the paper text.

    Args:
        model_name: Model id passed to the SambaNova chat completions API.
        paper_content: ``gr.State`` holding the extracted paper text (or None).
        hf_token_input: Textbox with an optional user-supplied SambaNova key.
        max_total_tokens: Token budget for the paper context plus the
            conversation. The paper context is truncated first; then the
            oldest turns are dropped.

    Returns:
        The constructed ``gr.ChatInterface``.
    """
    # Load tokenizer and cache it
    tokenizer = tokenizer_cache.get(model_name)
    if tokenizer is None:
        # model_name is a SambaNova id rather than a Hub repo id, so a fixed
        # Llama tokenizer is used only to estimate token counts.
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.2-1B-Instruct", token=os.getenv("HF_TOKEN")
        )
        tokenizer_cache[model_name] = tokenizer

    def encode(text):
        # No BOS/EOS: these ids are only counted, or decoded back into the
        # system prompt, where special-token text must not appear.
        return tokenizer.encode(text, add_special_tokens=False)

    def get_fn(message, history, paper_content_value, hf_token_value):
        """Stream the model's reply to ``message``, yielding the growing response text."""
        # Include the paper content as context
        if paper_content_value:
            context = f"The following is the content of the paper:\n{paper_content_value}\n\n"
        else:
            context = ""
        context_tokens = encode(context)

        messages = _history_to_messages(history)
        messages.append({"role": "user", "content": message})
        message_token_counts = [len(encode(m["content"])) for m in messages]
        conversation_tokens = sum(message_token_counts)

        # Over budget: shrink the paper context first so the conversation survives
        if len(context_tokens) + conversation_tokens > max_total_tokens:
            available_tokens = max_total_tokens - conversation_tokens
            if available_tokens > 0:
                context_tokens = context_tokens[:available_tokens]
                context = tokenizer.decode(context_tokens, skip_special_tokens=True)
            else:
                # Not enough space for context; remove it
                context_tokens = []
                context = ""

        # Still over budget: drop the oldest turns, always keeping the new message
        total_tokens = len(context_tokens) + conversation_tokens
        while total_tokens > max_total_tokens and len(messages) > 1:
            messages.pop(0)
            total_tokens -= message_token_counts.pop(0)

        final_messages = [{"role": "system", "content": context}] if context else []
        final_messages.extend(messages)

        # Prefer the user's key (ignoring stray whitespace from pasting)
        api_key = (hf_token_value or "").strip() or os.getenv("SAMBANOVA_API_KEY")
        if not api_key:
            raise ValueError("API token is not provided.")

        client = OpenAI(
            base_url="https://api.sambanova.ai/v1/",
            api_key=api_key,
        )

        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=final_messages,
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                # Some streamed chunks (e.g. trailing usage reports) carry no
                # choices. Indexing them would abort the answer with an error.
                if not chunk.choices:
                    continue
                response_text += chunk.choices[0].delta.content or ""
                yield response_text
        except Exception as e:
            yield f"Error: {str(e)}"

    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=gr.Chatbot(
            label="Chatbot",
            scale=1,
            height=400,
            autoscroll=True,
        ),
        additional_inputs=[paper_content, hf_token_input],
        # examples=["What are the main findings of this paper?", "Explain the methodology used in this research."]
    )
    return chat_interface