Fix indentation in app.py and remove redundant feature list from the Gradio interface description for improved clarity.
4a200c6
import gradio as gr
import os
import json
import re
from typing import Iterator, Dict, Any, List, Optional
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

# Load abstracts content once at startup
def load_abstracts_content():
    """Load the abstracts content once at startup to avoid reading the file on every request."""
    try:
        with open("abstracts.md", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return "Abstracts database not found."


# Load abstracts content globally
ABSTRACTS_CONTENT = load_abstracts_content()

# Load full paper texts
def load_paper_texts():
    """Load all paper texts from the Papers directory into a filename -> content mapping."""
    papers = {}
    papers_dir = "Papers"
    if not os.path.exists(papers_dir):
        return {}
    for filename in os.listdir(papers_dir):
        if filename.endswith('.txt'):
            filepath = os.path.join(papers_dir, filename)
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    content = f.read()
                # Store with the filename as key
                papers[filename] = content
            except Exception as e:
                papers[filename] = f"Error loading paper: {str(e)}"
    return papers


# Load paper texts globally
PAPER_TEXTS = load_paper_texts()

def normalize_filename(filename):
    """Normalize a filename for fuzzy matching."""
    # Remove the .txt extension
    if filename.endswith('.txt'):
        filename = filename[:-4]
    # Lowercase and strip special characters
    filename = re.sub(r'[^\w\s]', '', filename.lower())
    # Collapse whitespace
    filename = ' '.join(filename.split())
    return filename

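# Quick sanity check of the normalizer (hypothetical title, kept as a comment
# so nothing runs at import time):
#   normalize_filename("AI-Companions: Reduce Loneliness.txt")
#   -> "aicompanions reduce loneliness"
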
def find_matching_paper_file(query_terms, papers_dict):
    """Find the best matching paper file based on query terms."""
    query_normalized = normalize_filename(' '.join(query_terms))
    best_match = None
    best_score = 0
    for filename in papers_dict.keys():
        filename_normalized = normalize_filename(filename)
        score = 0
        # Exact substring match in either direction
        if query_normalized in filename_normalized or filename_normalized in query_normalized:
            score += 10
        # Whole-word overlap
        query_words = set(query_normalized.split())
        filename_words = set(filename_normalized.split())
        overlap = len(query_words.intersection(filename_words))
        score += overlap * 2
        # Partial word matches
        for query_word in query_words:
            for filename_word in filename_words:
                if query_word in filename_word or filename_word in query_word:
                    score += 1
        if score > best_score:
            best_score = score
            best_match = filename
    return best_match if best_score > 0 else None

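# Example, assuming a file named "AI Companions Reduce Loneliness.txt" exists
# (borrowed from the tool description further down):
#   find_matching_paper_file(["ai", "companions"], PAPER_TEXTS)
#   -> "AI Companions Reduce Loneliness.txt"
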
def get_relevant_papers_content(query, max_papers=5):
    """Get relevant paper content based on the user query."""
    query_terms = query.lower().split()
    relevant_papers = []
    for filename, content in PAPER_TEXTS.items():
        title = filename[:-4] if filename.endswith('.txt') else filename
        title_lower = title.lower()
        # Calculate a relevance score, weighting title hits above body hits
        score = 0
        for term in query_terms:
            if term in title_lower:
                score += 2
            if term in content.lower():
                score += 1
        if score > 0:
            relevant_papers.append((filename, content, score))
    # Sort by relevance score and return the top papers
    relevant_papers.sort(key=lambda x: x[2], reverse=True)
    return relevant_papers[:max_papers]

def get_full_paper_content(title, max_chars=12000):
    """Get full paper content for a specific title, truncated to max_chars."""
    for filename, content in PAPER_TEXTS.items():
        if title.lower() in filename.lower() or filename.lower() in title.lower():
            return content[:max_chars] + "..." if len(content) > max_chars else content
    return "Paper not found."

def get_paper_summary(title):
    """Get a structured summary of a paper by bucketing its lines into common sections."""
    content = get_full_paper_content(title)
    if content == "Paper not found.":
        return content
    # Buckets for the sections we try to detect
    sections = {
        'abstract': '',
        'introduction': '',
        'methodology': '',
        'results': '',
        'conclusions': ''
    }
    lines = content.split('\n')
    current_section = None
    for line in lines:
        line_lower = line.lower().strip()
        # Detect section headers by keyword
        if any(keyword in line_lower for keyword in ['abstract', 'introduction', 'method', 'methodology', 'results', 'conclusion']):
            if 'abstract' in line_lower:
                current_section = 'abstract'
            elif 'introduction' in line_lower:
                current_section = 'introduction'
            elif 'method' in line_lower:
                current_section = 'methodology'
            elif 'result' in line_lower:
                current_section = 'results'
            elif 'conclusion' in line_lower:
                current_section = 'conclusions'
        # Add the line to the current section
        if current_section and line.strip():
            sections[current_section] += line + '\n'
    # Assemble the structured summary
    summary = f"# {title}\n\n"
    for section, section_text in sections.items():
        if section_text.strip():
            summary += f"## {section.title()}\n{section_text.strip()}\n\n"
    return summary

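# Usage sketch (the title is hypothetical; matching is substring-based, so a
# partial title is enough):
#   print(get_paper_summary("AI Companions"))
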
# Get the API key with better error handling
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    print("⚠️ Warning: OPENAI_API_KEY environment variable not set!")
    client = None
else:
    client = OpenAI(
        api_key=api_key,
        timeout=60.0,
        max_retries=3
    )

# Available models: display name -> API model ID
AVAILABLE_MODELS = {
    "GPT-4o-mini": "gpt-4o-mini",
    "GPT-4o": "gpt-4o",
    "GPT-3.5 Turbo": "gpt-3.5-turbo"
}

# Define the tool for fetching papers
FETCH_PAPERS_TOOL = {
    "type": "function",
    "function": {
        "name": "fetch_papers",
        "description": "Fetch full text content of research papers by their filenames. Use this when you need detailed information, full text, conclusions, methodology, or specific quotes from papers.",
        "parameters": {
            "type": "object",
            "properties": {
                "filenames": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    },
                    "description": "List of paper filenames to fetch (e.g., ['The Labor Market Effects of Generativ.txt', 'AI Companions Reduce Loneliness.txt'])"
                }
            },
            "required": ["filenames"]
        }
    }
}

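# When the model decides to use this tool, it emits a function call whose
# arguments arrive as a JSON string (standard OpenAI tool-calling shape; the
# filename below is hypothetical):
#   {"name": "fetch_papers",
#    "arguments": "{\"filenames\": [\"AI Companions Reduce Loneliness.txt\"]}"}
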
def fetch_papers(filenames: List[str]) -> Dict[str, str]:
    """
    Fetch full paper texts by filename.
    Returns a dictionary mapping filename to content.
    """
    papers = {}
    papers_dir = "Papers"
    if not os.path.exists(papers_dir):
        return {"error": "Papers directory not found"}
    for filename in filenames:
        # Ensure the .txt extension
        if not filename.endswith('.txt'):
            filename += '.txt'
        filepath = os.path.join(papers_dir, filename)
        if os.path.exists(filepath):
            try:
                with open(filepath, "r", encoding="utf-8") as f:
                    papers[filename] = f.read()
            except Exception as e:
                papers[filename] = f"Error loading paper: {str(e)}"
        else:
            papers[filename] = f"Paper not found: {filename}"
    return papers

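# Example direct call (hypothetical filename; the .txt extension is appended
# automatically when missing):
#   fetch_papers(["AI Companions Reduce Loneliness"])
#   -> {"AI Companions Reduce Loneliness.txt": "<file contents>"}
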
def extract_conclusion_from_paper(content: str) -> str:
    """Extract the conclusion section from a paper's content."""
    conclusion_patterns = [
        "conclusion and future works",
        "conclusion and future work",
        "conclusions",
        "conclusion",
        "summary and conclusions",
        "discussion and conclusions"
    ]
    lines = content.split('\n')
    conclusion_start = -1
    for i, line in enumerate(lines):
        line_lower = line.lower().strip()
        if any(pattern in line_lower for pattern in conclusion_patterns):
            # Heuristic: treat the line as a header only if it looks like one
            # (all caps, short, ends with a colon, or starts with "Conclusion")
            if (line.isupper() or
                    line.strip().endswith(':') or
                    len(line.strip()) < 100 or
                    line.strip().startswith('Conclusion')):
                conclusion_start = i
                break
    if conclusion_start != -1:
        conclusion_lines = []
        for line in lines[conclusion_start:]:
            line_stripped = line.strip()
            # Stop at back matter or page markers
            if (line_stripped.lower().startswith('acknowledgments') or
                    line_stripped.lower().startswith('references') or
                    line_stripped.startswith('--- Page')):
                break
            conclusion_lines.append(line)
        return '\n'.join(conclusion_lines)
    # Fallback: return the last 1000 characters
    return content[-1000:] if len(content) > 1000 else content

def truncate_conversation_history(messages: list, max_tokens: int = 8000) -> list:
    """Truncate conversation history to stay within context limits.

    Note: despite the max_tokens parameter, this is a simple message-count
    heuristic; it keeps the system prompt plus the most recent exchanges.
    """
    if len(messages) <= 3:
        return messages
    system_message = messages[0]
    conversation_messages = messages[1:]
    # Drop the oldest user/assistant pair until at most 6 messages remain
    while len(conversation_messages) > 6:
        conversation_messages = conversation_messages[2:]
    return [system_message] + conversation_messages

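# Example: a history of [system prompt] + 8 conversation messages is reduced
# to the system prompt + the 6 most recent messages (7 messages total).
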
def respond(
    message: str,
    history: list[tuple[str, str]],
    model_name: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """
    Generate a streaming response using OpenAI's models with function calling.
    """
    if not client:
        yield "❌ Error: OpenAI API key not configured."
        return
    if not message.strip():
        yield "Please enter a message to start the conversation."
        return

    # Get relevant full paper content based on the user query
    relevant_papers_content = get_relevant_papers_content(message)

    # Check whether the user is asking for a specific paper
    # (e.g., "show me the full paper about pigs")
    specific_paper_content = ""
    conclusion_content = ""
    paper_summary_content = ""
    if any(keyword in message.lower() for keyword in ["full paper", "complete paper", "entire paper", "show me the paper", "read the paper", "summarize", "summary"]):
        # Try to find the specific paper content
        for filename, content in PAPER_TEXTS.items():
            title = filename[:-4] if filename.endswith('.txt') else filename
            if any(term in title.lower() for term in message.lower().split()):
                if any(keyword in message.lower() for keyword in ["summarize", "summary"]):
                    paper_summary_content = get_paper_summary(title)
                else:
                    specific_paper_content = get_full_paper_content(title)
                break

    # Check whether the user is asking for conclusions specifically
    if any(keyword in message.lower() for keyword in ["conclusion", "conclusions", "what's the conclusion", "what is the conclusion"]):
        for filename, content in PAPER_TEXTS.items():
            title = filename[:-4] if filename.endswith('.txt') else filename
            if any(term in title.lower() for term in message.lower().split()):
                conclusion_content = extract_conclusion_from_paper(content)
                break

    # Initialize messages with a comprehensive system prompt
    system_prompt = f"""You are an AI chatbot designed to help users explore and analyze AI research papers.
You have access to:
1. An abstracts database with summaries of research papers
2. Full paper texts for detailed analysis
3. A tool to fetch additional paper content when needed
ABSTRACTS DATABASE:
{ABSTRACTS_CONTENT}
RELEVANT PAPERS CONTENT:
{chr(10).join([f"Paper: {filename}{chr(10)}Content: {content[:3000]}..." for filename, content, score in relevant_papers_content])}
SPECIFIC PAPER CONTENT:
{specific_paper_content if specific_paper_content else "None"}
CONCLUSION CONTENT:
{conclusion_content if conclusion_content else "None"}
PAPER SUMMARY:
{paper_summary_content if paper_summary_content else "None"}
INSTRUCTIONS:
- Use the abstracts for general questions and overviews
- Use full paper content when users ask for specific details, conclusions, or complete papers
- Use the fetch_papers tool when you need additional paper content
- Provide accurate, detailed responses based on the actual paper content
- When referencing papers, use their actual titles from the filenames
- Prioritize full paper content over abstracts when available"""

    messages = [{"role": "system", "content": system_prompt}]
    # Add the conversation history
    for user_msg, assistant_msg in history:
        if user_msg and user_msg.strip():
            messages.append({"role": "user", "content": user_msg.strip()})
        if assistant_msg and assistant_msg.strip():
            messages.append({"role": "assistant", "content": assistant_msg.strip()})
    # Add the current user message
    messages.append({"role": "user", "content": message.strip()})
    # Truncate if needed
    messages = truncate_conversation_history(messages)

    try:
        model = AVAILABLE_MODELS.get(model_name, "gpt-4o-mini")
        # Initial response with tool support
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            tools=[FETCH_PAPERS_TOOL],
            tool_choice="auto",
            stream=True
        )
        # Collect the response and accumulate any streamed tool calls
        full_response = ""
        tool_calls = []
        current_tool_call = None
        for chunk in response:
            if hasattr(chunk.choices[0], 'delta'):
                delta = chunk.choices[0].delta
                # Handle regular content
                if delta.content is not None:
                    full_response += delta.content
                    yield full_response
                # Handle tool calls: the id and function name arrive in the
                # first chunk of a call; the JSON arguments stream in fragments
                if delta.tool_calls:
                    for tool_call_chunk in delta.tool_calls:
                        if tool_call_chunk.id:
                            # A new tool call begins; stash the previous one
                            if current_tool_call:
                                tool_calls.append(current_tool_call)
                            current_tool_call = {
                                "id": tool_call_chunk.id,
                                "type": "function",
                                "function": {
                                    "name": tool_call_chunk.function.name if tool_call_chunk.function else "",
                                    "arguments": ""
                                }
                            }
                        if current_tool_call and tool_call_chunk.function:
                            if tool_call_chunk.function.arguments:
                                current_tool_call["function"]["arguments"] += tool_call_chunk.function.arguments
        # Add the final tool call if one is still open
        if current_tool_call:
            tool_calls.append(current_tool_call)

        # Process tool calls, if any
        if tool_calls:
            # Add the assistant's message with its tool calls
            messages.append({
                "role": "assistant",
                "content": full_response if full_response else None,
                "tool_calls": tool_calls
            })
            # Execute the tool calls
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                if function_name == "fetch_papers":
                    try:
                        # Parse the accumulated JSON arguments
                        arguments = json.loads(tool_call["function"]["arguments"])
                        filenames = arguments.get("filenames", [])
                        # Fetch the papers
                        papers_content = fetch_papers(filenames)
                        # Add the tool response to the messages
                        tool_response = {
                            "role": "tool",
                            "tool_call_id": tool_call["id"],
                            "content": json.dumps(papers_content)
                        }
                        messages.append(tool_response)
                    except Exception as e:
                        tool_response = {
                            "role": "tool",
                            "tool_call_id": tool_call["id"],
                            "content": f"Error: {str(e)}"
                        }
                        messages.append(tool_response)
            # Get the final response that incorporates the tool results
            final_response = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stream=True
            )
            # Stream the final response
            final_text = ""
            for chunk in final_response:
                if hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta.content is not None:
                    final_text += chunk.choices[0].delta.content
                    yield full_response + "\n\n" + final_text if full_response else final_text

    except Exception as e:
        error_message = f"Error: {str(e)}"
        if "api_key" in str(e).lower():
            error_message = "Error: Invalid or missing OpenAI API key."
        elif "quota" in str(e).lower():
            error_message = "Error: API quota exceeded."
        elif "rate" in str(e).lower():
            error_message = "Error: Rate limit exceeded."
        yield error_message

def chat_fn(message, history, model_name, max_tokens, temperature, top_p):
    """Handle the entire chat interaction and stream updates to the chatbot."""
    if not message.strip():
        # Yield rather than return: this is a generator, so a bare
        # `return history` would never reach the UI
        yield history
        return
    history.append([message, ""])
    for response in respond(message, history[:-1], model_name, max_tokens, temperature, top_p):
        history[-1][1] = response
        yield history


def clear_history() -> tuple:
    """Clear the conversation history and the input box."""
    return [], ""

# Create the Gradio interface
with gr.Blocks(
    title="📚 AI Research Paper Chatbot",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 📚 AI Research Paper Chatbot
        Chat with an AI assistant that can intelligently retrieve and analyze research papers.
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=500,
                show_label=False,
                container=True,
                bubble_full_width=False
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    container=False,
                    scale=9
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="GPT-4o",
                label="Model",
                info="Select the AI model to use"
            )
            max_tokens_slider = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                label="Max Tokens",
                info="Maximum response length"
            )
            temperature_slider = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Creativity level"
            )
            top_p_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1.0,
                step=0.05,
                label="Top-p",
                info="Response diversity"
            )
            gr.Markdown("### 💡 Examples")
            example_btn1 = gr.Button("What papers discuss AI's impact on employment?", size="sm")
            example_btn2 = gr.Button("Show me the full paper about AI companions", size="sm")
            example_btn3 = gr.Button("Compare findings on AI in education", size="sm")

    # Event handlers
    msg.submit(
        chat_fn,
        [msg, chatbot, model_dropdown, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot],
        show_progress=True
    ).then(
        lambda: "",
        outputs=[msg]
    )
    submit_btn.click(
        chat_fn,
        [msg, chatbot, model_dropdown, max_tokens_slider, temperature_slider, top_p_slider],
        [chatbot],
        show_progress=True
    ).then(
        lambda: "",
        outputs=[msg]
    )
    clear_btn.click(clear_history, outputs=[chatbot, msg])
    # Example handlers: clicking an example fills the input box
    example_btn1.click(lambda: "What papers discuss AI's impact on employment?", outputs=msg)
    example_btn2.click(lambda: "Show me the full paper about AI companions", outputs=msg)
    example_btn3.click(lambda: "Compare findings on AI in education", outputs=msg)

if __name__ == "__main__":
    if not os.getenv("OPENAI_API_KEY"):
        print("⚠️ Warning: OPENAI_API_KEY environment variable not set!")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )