import streamlit as st
import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, LoraConfig


class Summarizer:
    def __init__(self):
        # Base model and tokenizer for individual abstract summaries
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
        self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
        self.base_model.eval()

        try:
            # LoRA config matching the fine-tuned adapter
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_2_SEQ_LM",
                target_modules=["q_proj", "v_proj"],
                inference_mode=True,
            )

            # Separate base model instance to attach the LoRA adapter to
            base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")

            # Load the PEFT adapter from the local path in the Space
            self.finetuned_model = PeftModel.from_pretrained(
                base_model_for_finetuned,
                "models",  # Points to the models directory in the Space
                config=lora_config,
            )
            self.finetuned_model.eval()
        except Exception as e:
            st.error(f"Error loading fine-tuned model: {str(e)}")
            raise

    def summarize_text(self, text, max_length=150, use_finetuned=False):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        generation_kwargs = {
            "max_length": max_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True,
        }

        with torch.no_grad():
            if use_finetuned:
                summary_ids = self.finetuned_model.generate(**inputs, **generation_kwargs)
            else:
                summary_ids = self.base_model.generate(**inputs, **generation_kwargs)

        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_excel(self, file, question):
        # Read the uploaded Excel file (requires openpyxl for .xlsx)
        df = pd.read_excel(file)

        # Summarize each abstract individually with the base model
        summaries = []
        progress_bar = st.progress(0.0)
        for idx, row in df.iterrows():
            if pd.notna(row['Abstract']):
                # Update the progress bar in place
                progress_bar.progress((idx + 1) / len(df))

                paper_info = {
                    'title': row['Article Title'],
                    'authors': row['Authors'] if pd.notna(row['Authors']) else '',
                    'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
                    'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
                    'doi': row['DOI'] if pd.notna(row['DOI']) else '',
                    'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
                    'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
                    'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
                    'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
                    'summary': self.summarize_text(row['Abstract'], use_finetuned=False),
                }
                summaries.append(paper_info)

        # Generate the overall summary from the combined individual summaries using the fine-tuned model
        combined_summaries = " ".join([s['summary'] for s in summaries])
        overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)

        return summaries, overall_summary


# Set page config
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")

# Initialize session state
if 'summarizer' not in st.session_state:
    st.session_state['summarizer'] = None
if 'summaries' not in st.session_state:
    st.session_state['summaries'] = None
if 'overall_summary' not in st.session_state:
    st.session_state['overall_summary'] = None

# App title and description
st.title("Research Paper Summarizer")
st.write("Upload an Excel file with research papers to generate summaries")

# Sidebar for inputs
with st.sidebar:
    st.header("Input Options")
    uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
    question = st.text_area("Enter your research question")

    # Generate button
    generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)

if generate_button and uploaded_file and question:
    try:
        # Initialize the summarizer once and reuse it across reruns
        if st.session_state['summarizer'] is None:
            st.session_state['summarizer'] = Summarizer()

        with st.spinner("Generating summaries..."):
            summaries, overall_summary = st.session_state['summarizer'].process_excel(uploaded_file, question)
            st.session_state['summaries'] = summaries
            st.session_state['overall_summary'] = overall_summary

        st.success("Summaries generated successfully!")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
elif generate_button:
    st.warning("Please upload a file and enter a research question.")

# Main content area
if st.session_state['overall_summary']:
    st.header("Overall Summary")
    st.write(st.session_state['overall_summary'])

if st.session_state['summaries']:
    st.header("Individual Paper Summaries")

    # Sorting options
    col1, col2 = st.columns([2, 3])
    with col1:
        sort_by = st.selectbox(
            "Sort by",
            ["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
            index=0,
        )

    # Sort summaries based on the selection; papers with a missing year sort last
    summaries = st.session_state['summaries']
    if sort_by == "Year":
        summaries.sort(key=lambda x: x['year'] or 0, reverse=True)
    elif sort_by == "Citations":
        summaries.sort(key=lambda x: x['times_cited'], reverse=True)
    elif sort_by == "Source":
        summaries.sort(key=lambda x: x['source'])
    elif sort_by == "Type":
        summaries.sort(key=lambda x: x['document_type'])
    elif sort_by == "Access":
        summaries.sort(key=lambda x: x['open_access'])
    elif sort_by == "Research Areas":
        summaries.sort(key=lambda x: x['research_areas'])

    # Display summaries in expandable sections
    for paper in summaries:
        with st.expander(f"{paper['title']} ({paper['year']})"):
            col1, col2 = st.columns([2, 1])
            with col1:
                st.write("**Summary:**")
                st.write(paper['summary'])
            with col2:
                st.write(f"**Authors:** {paper['authors']}")
                st.write(f"**Source:** {paper['source']}")
                st.write(f"**DOI:** {paper['doi']}")
                st.write(f"**Document Type:** {paper['document_type']}")
                st.write(f"**Times Cited:** {paper['times_cited']}")
                st.write(f"**Open Access:** {paper['open_access']}")
                st.write(f"**Research Areas:** {paper['research_areas']}")