import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import gc
from pathlib import Path
import concurrent.futures
import time
import nltk
from nltk.tokenize import sent_tokenize
from concurrent.futures import ThreadPoolExecutor

nltk.download('punkt')

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state
if 'relevant_papers' not in st.session_state:
    st.session_state.relevant_papers = None
if 'relevance_scores' not in st.session_state:
    st.session_state.relevance_scores = None
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'summaries' not in st.session_state:
    st.session_state.summaries = None
if 'text_processor' not in st.session_state:
    st.session_state.text_processor = None
if 'processing_started' not in st.session_state:
    st.session_state.processing_started = False
if 'focused_summary_generated' not in st.session_state:
    st.session_state.focused_summary_generated = False
if 'current_model' not in st.session_state:
    st.session_state.current_model = None
if 'current_tokenizer' not in st.session_state:
    st.session_state.current_tokenizer = None
if 'model_type' not in st.session_state:
    st.session_state.model_type = None
if 'focused_summary' not in st.session_state:
    st.session_state.focused_summary = None

# TextProcessor import, with a lightweight fallback if text_processing is unavailable
try:
    from text_processing import TextProcessor
except ImportError:
    class TextProcessor:
        def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
            return {
                'top_indices': list(range(min(top_k, len(abstracts)))),
                'scores': [1.0] * min(top_k, len(abstracts))
            }


def load_model(model_type):
    """Load the appropriate model based on type, with proper memory management."""
    try:
        # Clear any existing cached data
        gc.collect()
        torch.cuda.empty_cache()
        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            # Load the fine-tuned summarization model directly
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


def get_model(model_type):
    """Get the model from session state, loading it if needed."""
    try:
        if (st.session_state.current_model is None
                or st.session_state.model_type != model_type):
            # Clean up existing model
            if st.session_state.current_model is not None:
                cleanup_model(st.session_state.current_model,
                              st.session_state.current_tokenizer)
            # Load new model
            model, tokenizer = load_model(model_type)
            st.session_state.current_model = model
            st.session_state.current_tokenizer = tokenizer
            st.session_state.model_type = model_type
        return st.session_state.current_model, st.session_state.current_tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.session_state.processing_started = False
        return None, None


def cleanup_model(model, tokenizer):
"""Properly cleanup model resources""" try: del model del tokenizer torch.cuda.empty_cache() gc.collect() except Exception: pass @st.cache_data def process_excel(uploaded_file): """Process uploaded Excel file""" try: df = pd.read_excel(uploaded_file) required_columns = ['Abstract', 'Article Title', 'Authors', 'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases'] # Check required columns first missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: st.error("â Missing required columns: " + ", ".join(missing_columns)) st.error("Please ensure your Excel file contains all required columns.") return None # Only proceed with validation if all required columns exist if len(df) > 5: st.error("â Your file contains more than 5 papers. Please upload a file with maximum 5 papers.") return None # Now safe to validate structure as we know columns exist is_valid, messages = validate_excel_structure(df) if not is_valid: for msg in messages: st.error(f"â {msg}") return None return df[required_columns] except Exception as e: st.error(f"â Error reading file: {str(e)}") st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)") return None def validate_excel_structure(df): """Validate the structure and content of the Excel file""" validation_messages = [] # Check for minimum content if len(df) == 0: validation_messages.append("File contains no data") return False, validation_messages try: # Check publication year format - this is useful for sorting/filtering df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce') if df['Publication Year'].isna().any(): validation_messages.append("Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)") else: years = df['Publication Year'].dropna() if len(years) > 0: if years.min() < 1900 or years.max() > 2025: validation_messages.append("Publication years must be between 1900 and 2025") # For short abstracts - just show a warning short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50 if short_abstracts.any(): st.warning("âšī¸ Some abstracts are quite short, but will still be processed") except Exception as e: validation_messages.append(f"Error checking data format: {str(e)}") return len(validation_messages) == 0, validation_messages def preprocess_text(text): """Clean biomedical text by handling common formatting issues and standardizing structure.""" if not isinstance(text, str) or not text.strip(): return text # Remove extra whitespace text = ' '.join(text.split()) # Roman numeral conversion roman_map = {'i': '1', 'ii': '2', 'iii': '3', 'iv': '4', 'v': '5', 'vi': '6', 'vii': '7', 'viii': '8', 'ix': '9', 'x': '10'} def replace_roman(match): roman = match.group(1).lower() return f"({roman_map.get(roman, roman)})" text = re.sub(r'\(([ivx]+)\)', replace_roman, text) # Clean enumerated lists for roman in roman_map: text = re.sub(f"\\b{roman}\\)", f"{roman_map[roman]})", text, flags=re.IGNORECASE) # Standardize section headers section_patterns = { r'\b(?:introduction|purpose|background|objectives?|context)\s*:?\s*': 'Background: ', r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:?\s*': 'Methods: ', r'\b(?:results?|findings?|observations?)\s*:?\s*': 'Results: ', r'\b(?:conclusions?|summary|final\s+remarks?)\s*:?\s*': 'Conclusions: ', r'\b(?:results?\s+and\s+conclusions?)\s*:?\s*(?=.*?:)': '', # Remove if followed by another section r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)': 'Results: ' # 
    }

    for pattern, replacement in section_patterns.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Ensure complete sentences in sections
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)(?=\s*(?:[A-Z][^:]*:|$))',
                  lambda m: f" {m.group(1)}." if m.group(1) and not m.group(1).strip().endswith('.') else m.group(0),
                  text)

    # Fix truncated sentences
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)\s*(?=[A-Z][^:]*:)',
                  lambda m: f" {m.group(1)}." if m.group(1) else "",
                  text)

    # Clean formatting
    text = re.sub(r'[\r\n]+', ' ', text)
    text = re.sub(r'\s*:\s*', ': ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    text = re.sub(r'[•‣†‡]|\*', '', text)  # Strip bullet and footnote symbols
    text = re.sub(r'\\n|\\r', ' ', text)
    text = re.sub(r'\s*\(\s*', ' (', text)
    text = re.sub(r'\s*\)\s*', ') ', text)

    # Fix statistical notations
    text = re.sub(r'p\s*[<=>]\s*0\.\d+', lambda m: m.group().replace(' ', ''), text)
    text = re.sub(r'(?<=\d)\s*%', '%', text)

    # Fix abbreviation spacing
    text = re.sub(r'(?<=\w)vs\.(?=\w)', 'vs. ', text)
    text = re.sub(r'(?<=\w)et\s+al\.(?=\w)', 'et al. ', text)

    # Remove repeated punctuation
    text = re.sub(r'([.!?])\1+', r'\1', text)

    # Final cleanup
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    text = text.strip()
    if not text.endswith('.'):
        text += '.'

    return text


def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate a structured summary of the given abstracts.

    Note: the question argument is accepted for compatibility but is not
    currently included in the prompt.
    """
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts if abstract.strip()]
    abstracts_content = " [SEP] ".join(formatted_abstracts)

    prompt = f"""
    Provide a factual summary structured as:
    - Background: Context and origin only if present
    - Methods: Key procedures and approaches
    - Results: Specific findings with numbers
    - Conclusions: Main implications

    Requirements:
    - Present sections sequentially
    - Merge related points within sections
    - Complete all sentences
    - Avoid repeating section headers
    - Use original terminology

    Content: {abstracts_content}
    """

    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        # Deterministic beam search (no sampling)
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=512,
            min_length=200,
            num_beams=4,
            length_penalty=2.0,
            no_repeat_ngram_size=3,
            do_sample=False
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(summary)


def post_process_summary(summary):
    """Post-process summary with improved section handling and formatting."""
    if not summary:
        return summary

    valid_sections = ['Background', 'Methods', 'Results', 'Conclusions']
    sections = {}
    current_section = None
    current_content = []

    # Pre-clean section headers
    summary = re.sub(r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)', 'Results:', summary, flags=re.IGNORECASE)
    summary = re.sub(r'\bresults?\s*and\s*conclusions?\s*:', 'Results:', summary, flags=re.IGNORECASE)

    # Process sentence by sentence
    lines = [line.strip() for line in summary.split('.') if line.strip()]

    for line in lines:
        section_match = None
        for section in valid_sections:
            if re.match(fr'\b{section}:', line, re.IGNORECASE):
                section_match = section
                break

        if section_match:
            if current_section:
                content = ' '.join(current_content)
                if content:
                    sections[current_section] = content
            current_section = section_match
            content = re.sub(fr'\b{section_match}:\s*', '', line, flags=re.IGNORECASE)
            current_content = [content] if content else []
        elif current_section:
            # Prevent section header splitting
            if not any(sect.lower() in line.lower() for sect in valid_sections):
                current_content.append(line)

    if current_section and current_content:
        sections[current_section] = ' '.join(current_content)

    # Format sections
    formatted_sections = []
    for section in valid_sections:
        if section in sections:
            content = sections[section].strip()
            if content:
                # Complete truncated sentences
                if not re.search(r'[.!?]$', content):
                    if len(content.split()) >= 3:  # Only complete if substantial
                        content += '.'
                # Ensure capitalization
                content = content[0].upper() + content[1:]
                # Fix double periods
                content = re.sub(r'\.+', '.', content)
                formatted_sections.append(f"{section}: {content}")

    return ' '.join(formatted_sections)


def process_papers_in_batches(df, model, tokenizer, batch_size=2):
    """Process papers in parallel for better efficiency."""
    abstracts = df['Abstract'].tolist()
    summaries = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        # Submit one summarization task per abstract; dict insertion order preserves paper order
        future_to_batch = {
            executor.submit(generate_focused_summary,
                            "Focus on key findings and methods.",
                            [abstract], model, tokenizer): abstract
            for abstract in abstracts
        }
        for future in future_to_batch:
            summaries.append(future.result())
    return summaries


def create_filter_controls(df, sort_column):
    """Create appropriate filter controls based on the selected column."""
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year', min_value=year_min, max_value=year_max, value=year_min)
        with col2:
            end_year = st.number_input('To Year', min_value=year_min, max_value=year_max, value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]

    elif sort_column == 'Authors':
        # Multi-select for authors
        unique_authors = sorted(set(
            author.strip()
            for authors in df['Authors'].dropna()
            for author in authors.split(';')
        ))
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(
                    lambda x: any(author in str(x) for author in selected_authors)
                )
            ]

    elif sort_column == 'Source Title':
        # Multi-select for source titles
        unique_sources = sorted(df['Source Title'].unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Article Title':
        # Only alphabetical sorting, no filtering
        pass

    return filtered_df


def main():
    st.title("🔬 Biomedical Papers Analysis")
    st.info("""
    **📋 File Upload Requirements:**
    - Excel file (.xlsx or .xls) with **maximum 5 papers**
    - Must contain these columns:
      • Abstract
      • Article Title
      • Authors
      • Source Title
      • Publication Year
      • DOI
      • Times Cited, All Databases
    """)

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers (max 5 papers)",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - placed early but hidden initially
    question_container = st.empty()
    question = ""

    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    df = df.dropna(subset=["Abstract"])
                    if len(df) > 0:
                        st.session_state.processed_data = df
                        st.success(f"✅ Successfully loaded {len(df)} papers with abstracts")
                    else:
                        st.error("❌ No valid papers found after processing. Please check your file.")
        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📚 Loaded {len(df)} papers with abstracts")

            # Get question before processing
            with question_container:
                question = st.text_input(
                    "Enter your research question (optional):",
                    help="If provided, a question-focused summary will be generated after individual summaries"
                )

            # Single button for both processes
            if not st.session_state.get('processing_started', False):
                if st.button("Start Analysis"):
                    st.session_state.processing_started = True

            # Show processing status and results
            if st.session_state.get('processing_started', False):
                # Individual Summaries Section
                st.header("📝 Individual Paper Summaries")

                # Generate summaries if not already done
                if st.session_state.summaries is None:
                    try:
                        with st.spinner("Generating individual paper summaries..."):
                            model, tokenizer = get_model("summarize")
                            if model is None or tokenizer is None:
                                reset_processing_state()
                                return
                            start_time = time.time()
                            st.session_state.summaries = process_papers_in_batches(
                                df, model, tokenizer, batch_size=2
                            )
                            end_time = time.time()
                            st.write(f"Processing time: {end_time - start_time:.2f} seconds")
                    except Exception as e:
                        st.error(f"Error generating summaries: {str(e)}")
                        reset_processing_state()

                # Display summaries with improved sorting and filtering
                if st.session_state.summaries is not None:
                    col1, col2 = st.columns(2)
                    with col1:
                        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
                        sort_column = st.selectbox("Sort/Filter by:", sort_options)
                    with col2:
                        if sort_column == 'Article Title':
                            ascending = st.radio(
                                "Sort order",
                                ["A to Z", "Z to A"],
                                horizontal=True
                            ) == "A to Z"
                        elif sort_column == 'Times Cited':
                            ascending = st.radio(
                                "Sort order",
                                ["Most cited first", "Least cited first"],
                                horizontal=True
                            ) == "Least cited first"
                        else:
                            ascending = True  # Default for other columns

                    # Create display dataframe
                    display_df = df.copy()
                    display_df['Summary'] = st.session_state.summaries
                    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
                    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
                    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

                    # Apply filters
                    filtered_df = create_filter_controls(display_df, sort_column)

                    # Apply sorting
                    if sort_column in ('Times Cited', 'Article Title'):
                        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
                    else:
                        sorted_df = filtered_df

                    # Show number of filtered results
                    if len(sorted_df) != len(display_df):
                        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

                    # Apply custom styling
                    st.markdown("""
                    """, unsafe_allow_html=True)

                    # Display papers using the filtered and sorted dataframe
                    for _, row in sorted_df.iterrows():
                        paper_info_cols = st.columns([1, 1])
                        with paper_info_cols[0]:  # PAPER column
                            st.markdown('