# Hugging Face Space app (scraped page header reported Space status: "Sleeping")
import streamlit as st | |
import pandas as pd | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
from peft import PeftModel, PeftConfig, LoraConfig | |
import torch | |
class Summarizer:
    """Summarize research-paper abstracts with BioBART.

    Uses the plain ``GanjinZero/biobart-base`` checkpoint for per-paper
    summaries and a LoRA fine-tuned PEFT adapter (loaded from the local
    ``models`` directory of the Space) for the combined overall summary.
    """

    def __init__(self):
        # Base model + tokenizer for individual abstract summaries.
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
        self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
        self.base_model.eval()  # inference only
        try:
            # LoRA configuration — must match the adapter that was trained.
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_2_SEQ_LM",
                target_modules=["q_proj", "v_proj"],
                inference_mode=True,
            )
            # Separate base instance so attaching the adapter does not
            # mutate self.base_model.
            base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
            # "models" points at the adapter weights stored inside the Space.
            self.finetuned_model = PeftModel.from_pretrained(
                base_model_for_finetuned,
                "models",
                config=lora_config,
            )
            self.finetuned_model.eval()
        except Exception as e:
            st.error(f"Error loading fine-tuned model: {str(e)}")
            raise

    def summarize_text(self, text, max_length=150, use_finetuned=False):
        """Return a beam-search summary of *text*.

        Args:
            text: Input text; tokenized with truncation to 512 tokens.
            max_length: Maximum token length of the generated summary.
            use_finetuned: If True, generate with the LoRA fine-tuned model
                instead of the base model.

        Returns:
            The decoded summary string (special tokens stripped).
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        model = self.finetuned_model if use_finetuned else self.base_model
        generation_kwargs = {
            "max_length": max_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True,
        }
        # FIX: pass the full encoding (input_ids + attention_mask) to BOTH
        # models — the old base-model path dropped the attention mask.
        # no_grad: inference only; avoids building autograd graphs.
        with torch.no_grad():
            summary_ids = model.generate(**inputs, **generation_kwargs)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_excel(self, file, question):
        """Summarize every abstract in an Excel sheet, then the whole set.

        Args:
            file: File-like object or path readable by ``pd.read_excel``.
            question: The user's research question (currently unused here;
                kept for interface stability with the UI caller).

        Returns:
            Tuple ``(summaries, overall_summary)`` where *summaries* is a
            list of per-paper metadata dicts (each with a 'summary' key) and
            *overall_summary* is a summary of all per-paper summaries.
        """
        df = pd.read_excel(file)
        summaries = []
        # FIX: st.progress(...) is not a context manager and calling it
        # repeatedly creates a new bar per row. Keep the returned handle
        # and update that single bar instead.
        progress_bar = st.progress(0.0)
        total = len(df)
        for idx, row in df.iterrows():
            # Update progress for every row (guard against an empty sheet).
            if total:
                progress_bar.progress((idx + 1) / total)
            if pd.notna(row['Abstract']):
                paper_info = {
                    'title': row['Article Title'],
                    'authors': row['Authors'] if pd.notna(row['Authors']) else '',
                    'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
                    'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
                    'doi': row['DOI'] if pd.notna(row['DOI']) else '',
                    'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
                    'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
                    'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
                    'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
                    'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
                }
                summaries.append(paper_info)
        # FIX: with no usable abstracts the old code summarized "".
        if not summaries:
            return [], ""
        # Overall summary of all per-paper summaries via the fine-tuned model.
        combined_summaries = " ".join(s['summary'] for s in summaries)
        overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)
        return summaries, overall_summary
# ---------- Streamlit page configuration ----------
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")

# Session-state slots, created lazily so values survive reruns.
for _state_key in ('summarizer', 'summaries', 'overall_summary'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# ---------- Page header ----------
st.title("Research Paper Summarizer")
st.write("Upload an Excel file with research papers to generate summaries")

# ---------- Sidebar inputs ----------
with st.sidebar:
    st.header("Input Options")
    uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
    question = st.text_area("Enter your research question")
    # Primary action button spanning the sidebar width.
    generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)
# Run generation only when the button is pressed AND both inputs are present.
if generate_button and uploaded_file and question:
    try:
        # Construct the summarizer once and cache it across reruns.
        if st.session_state['summarizer'] is None:
            st.session_state['summarizer'] = Summarizer()
        with st.spinner("Generating summaries..."):
            result = st.session_state['summarizer'].process_excel(uploaded_file, question)
            st.session_state['summaries'], st.session_state['overall_summary'] = result
            st.success("Summaries generated successfully!")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
elif generate_button:
    # Button pressed with missing inputs — prompt the user.
    st.warning("Please upload a file and enter a research question.")
# ---------- Main content: overall summary (shown once generated) ----------
if st.session_state['overall_summary']:
    st.header("Overall Summary")
    st.write(st.session_state['overall_summary'])
if st.session_state['summaries']:
    st.header("Individual Paper Summaries")

    # Sorting options.
    col1, col2 = st.columns([2, 3])
    with col1:
        sort_by = st.selectbox(
            "Sort by",
            ["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
            index=0
        )

    def _numeric_key(value):
        """Coerce a possibly-missing value to float for sorting.

        process_excel stores '' / 0 placeholders for missing Year and
        Citations, so rows can mix numbers with strings; non-numeric
        values sort last (as -inf under reverse=True).
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return float('-inf')

    # FIX: sorting raw 'year'/'times_cited' raised TypeError when numeric
    # pandas values were mixed with the '' placeholder; coerce every key.
    # Dispatch table: option -> (key function, descending?).
    _sort_dispatch = {
        "Year": (lambda x: _numeric_key(x['year']), True),
        "Citations": (lambda x: _numeric_key(x['times_cited']), True),
        "Source": (lambda x: str(x['source']), False),
        "Type": (lambda x: str(x['document_type']), False),
        "Access": (lambda x: str(x['open_access']), False),
        "Research Areas": (lambda x: str(x['research_areas']), False),
    }

    summaries = st.session_state['summaries']
    key_fn, descending = _sort_dispatch[sort_by]
    summaries.sort(key=key_fn, reverse=descending)

    # One expandable card per paper: summary on the left, metadata right.
    for paper in summaries:
        with st.expander(f"{paper['title']} ({paper['year']})"):
            col1, col2 = st.columns([2, 1])
            with col1:
                st.write("**Summary:**")
                st.write(paper['summary'])
            with col2:
                st.write(f"**Authors:** {paper['authors']}")
                st.write(f"**Source:** {paper['source']}")
                st.write(f"**DOI:** {paper['doi']}")
                st.write(f"**Document Type:** {paper['document_type']}")
                st.write(f"**Times Cited:** {paper['times_cited']}")
                st.write(f"**Open Access:** {paper['open_access']}")
                st.write(f"**Research Areas:** {paper['research_areas']}")