# summarizer / app.py
import streamlit as st
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig, LoraConfig
import torch
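
# Streamlit app that summarizes research papers from an uploaded Excel export.
# Each abstract is summarized with the BioBART base model; the concatenated
# per-paper summaries are then condensed into one overall summary with a LoRA
# fine-tuned adapter loaded from the local "models" directory of the Space.
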
class Summarizer:
    def __init__(self):
        # Base model for individual summaries
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
        self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")

        try:
            # Create LoRA config
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_2_SEQ_LM",
                target_modules=["q_proj", "v_proj"],
                inference_mode=True
            )
            # Load a second copy of the base model to attach the fine-tuned adapter to
            base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
            # Load the PEFT model from the local path in the Space
            self.finetuned_model = PeftModel.from_pretrained(
                base_model_for_finetuned,
                "models",  # Points to the models directory in the Space
                config=lora_config
            )
            self.finetuned_model.eval()
        except Exception as e:
            st.error(f"Error loading fine-tuned model: {str(e)}")
            raise
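
    # summarize_text tokenizes the input (truncated to 512 tokens) and generates a
    # beam-search summary (4 beams), using either the base or the fine-tuned model.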
    def summarize_text(self, text, max_length=150, use_finetuned=False):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        if use_finetuned:
            generation_kwargs = {
                "max_length": max_length,
                "num_beams": 4,
                "length_penalty": 2.0,
                "early_stopping": True
            }
            summary_ids = self.finetuned_model.generate(**inputs, **generation_kwargs)
        else:
            summary_ids = self.base_model.generate(
                inputs["input_ids"],
                max_length=max_length,
                num_beams=4,
                length_penalty=2.0,
                early_stopping=True
            )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
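
    # process_excel expects an Excel sheet whose column names match a Web of
    # Science-style export: 'Article Title', 'Authors', 'Source Title',
    # 'Publication Year', 'DOI', 'Document Type', 'Times Cited, WoS Core',
    # 'Open Access Designations', 'Research Areas', and 'Abstract'.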
    def process_excel(self, file, question):
        # Read Excel file
        df = pd.read_excel(file)

        # Extract abstracts and generate individual summaries using the base model
        summaries = []
        progress_bar = st.progress(0.0)
        for idx, row in df.iterrows():
            if pd.notna(row['Abstract']):
                # Update the progress bar in place (st.progress is not a context
                # manager, so keep a single bar and call .progress() on it)
                progress_bar.progress((idx + 1) / len(df))
                paper_info = {
                    'title': row['Article Title'],
                    'authors': row['Authors'] if pd.notna(row['Authors']) else '',
                    'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
                    'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
                    'doi': row['DOI'] if pd.notna(row['DOI']) else '',
                    'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
                    'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
                    'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
                    'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
                    'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
                }
                summaries.append(paper_info)

        # Generate the overall summary from the individual summaries using the fine-tuned model
        combined_summaries = " ".join([s['summary'] for s in summaries])
        overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)
        return summaries, overall_summary
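
# Example usage outside Streamlit (hypothetical sketch; assumes the LoRA adapter
# files are present under ./models and the Excel file uses the columns listed above):
#     summarizer = Summarizer()
#     paper_summaries, overall = summarizer.process_excel("papers.xlsx", "my research question")
#     print(overall)
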
# Set page config
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")
# Initialize session state
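# Streamlit reruns this script on every interaction, so the loaded models and the
# generated summaries are kept in session_state to avoid recomputing them.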
if 'summarizer' not in st.session_state:
    st.session_state['summarizer'] = None
if 'summaries' not in st.session_state:
    st.session_state['summaries'] = None
if 'overall_summary' not in st.session_state:
    st.session_state['overall_summary'] = None
# App title and description
st.title("Research Paper Summarizer")
st.write("Upload an Excel file with research papers to generate summaries")
# Sidebar for inputs
with st.sidebar:
    st.header("Input Options")
    uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
    question = st.text_area("Enter your research question")

    # Generate button
    generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)
if generate_button and uploaded_file and question:
    try:
        # Initialize summarizer if not already done
        if st.session_state['summarizer'] is None:
            st.session_state['summarizer'] = Summarizer()

        with st.spinner("Generating summaries..."):
            summaries, overall_summary = st.session_state['summarizer'].process_excel(uploaded_file, question)
            st.session_state['summaries'] = summaries
            st.session_state['overall_summary'] = overall_summary
        st.success("Summaries generated successfully!")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
elif generate_button:
    st.warning("Please upload a file and enter a research question.")
# Main content area
if st.session_state['overall_summary']:
    st.header("Overall Summary")
    st.write(st.session_state['overall_summary'])
if st.session_state['summaries']:
    st.header("Individual Paper Summaries")

    # Sorting options
    col1, col2 = st.columns([2, 3])
    with col1:
        sort_by = st.selectbox(
            "Sort by",
            ["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
            index=0
        )

    # Sort summaries based on selection
    summaries = st.session_state['summaries']
    if sort_by == "Year":
        # Missing years are stored as '', so fall back to 0 to keep the sort key numeric
        summaries.sort(key=lambda x: x['year'] or 0, reverse=True)
    elif sort_by == "Citations":
        summaries.sort(key=lambda x: x['times_cited'], reverse=True)
    elif sort_by == "Source":
        summaries.sort(key=lambda x: x['source'])
    elif sort_by == "Type":
        summaries.sort(key=lambda x: x['document_type'])
    elif sort_by == "Access":
        summaries.sort(key=lambda x: x['open_access'])
    elif sort_by == "Research Areas":
        summaries.sort(key=lambda x: x['research_areas'])

    # Display summaries in expandable sections
    for paper in summaries:
        with st.expander(f"{paper['title']} ({paper['year']})"):
            col1, col2 = st.columns([2, 1])
            with col1:
                st.write("**Summary:**")
                st.write(paper['summary'])
            with col2:
                st.write(f"**Authors:** {paper['authors']}")
                st.write(f"**Source:** {paper['source']}")
                st.write(f"**DOI:** {paper['doi']}")
                st.write(f"**Document Type:** {paper['document_type']}")
                st.write(f"**Times Cited:** {paper['times_cited']}")
                st.write(f"**Open Access:** {paper['open_access']}")
                st.write(f"**Research Areas:** {paper['research_areas']}")