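"""Streamlit app for summarizing research papers from a Web of Science-style
Excel export.

Each paper's abstract is summarized individually with GanjinZero/biobart-base;
the individual summaries are then condensed into an overall summary by a LoRA
fine-tuned adapter loaded from the Space's local `models` directory.

Expected columns: Article Title, Authors, Source Title, Publication Year, DOI,
Document Type, "Times Cited, WoS Core", Open Access Designations, Research
Areas, Abstract. The research question entered in the sidebar is collected but
not currently used during summarization.
"""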
import streamlit as st
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, LoraConfig
import torch
class Summarizer:
def __init__(self):
# Base model for individual summaries
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
        self.base_model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
try:
            # Recreate the LoRA config used during fine-tuning; these values
            # must match the hyperparameters the adapter was trained with
            lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.1,
bias="none",
task_type="SEQ_2_SEQ_LM",
target_modules=["q_proj", "v_proj"],
inference_mode=True
)
            # Separate base model instance to attach the fine-tuned LoRA adapter to
base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
# Load the PEFT model from local path in the Space
self.finetuned_model = PeftModel.from_pretrained(
base_model_for_finetuned,
"models", # Points to the models directory in the Space
config=lora_config
)
self.finetuned_model.eval()
except Exception as e:
st.error(f"Error loading fine-tuned model: {str(e)}")
raise
    def summarize_text(self, text, max_length=150, use_finetuned=False):
        # Tokenize the input, truncating to the model's 512-token limit
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        generation_kwargs = {
            "max_length": max_length,
            "num_beams": 4,
            "length_penalty": 2.0,
            "early_stopping": True
        }
        model = self.finetuned_model if use_finetuned else self.base_model
        # Inference only, so skip gradient tracking; passing **inputs also
        # forwards the attention mask alongside the input IDs
        with torch.no_grad():
            summary_ids = model.generate(**inputs, **generation_kwargs)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def process_excel(self, file, question):
# Read Excel file
df = pd.read_excel(file)
# Extract abstracts and generate individual summaries using base model
summaries = []
        # st.progress is not a context manager: create one bar and update it
        # in place rather than spawning a new bar on every iteration
        progress_bar = st.progress(0.0)
        for idx, row in df.iterrows():
            if pd.notna(row['Abstract']):
                progress_bar.progress((idx + 1) / len(df))
                paper_info = {
                    'title': row['Article Title'],
                    'authors': row['Authors'] if pd.notna(row['Authors']) else '',
                    'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
                    'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
                    'doi': row['DOI'] if pd.notna(row['DOI']) else '',
                    'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
                    'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
                    'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
                    'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
                    'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
                }
                summaries.append(paper_info)
        # Condense the combined summaries with the fine-tuned model; note that
        # summarize_text truncates input beyond 512 tokens
        combined_summaries = " ".join([s['summary'] for s in summaries])
        overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)
return summaries, overall_summary
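# Illustrative standalone usage (a sketch; 'papers.xlsx' is a placeholder for
# a spreadsheet with the columns listed in the module docstring):
#   summarizer = Summarizer()
#   with open("papers.xlsx", "rb") as f:
#       summaries, overall = summarizer.process_excel(f, "example question")
#   print(overall)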
# Set page config
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")
# Initialize session state
if 'summarizer' not in st.session_state:
st.session_state['summarizer'] = None
if 'summaries' not in st.session_state:
st.session_state['summaries'] = None
if 'overall_summary' not in st.session_state:
st.session_state['overall_summary'] = None
# App title and description
st.title("Research Paper Summarizer")
st.write("Upload an Excel file with research papers to generate summaries")
# Sidebar for inputs
with st.sidebar:
st.header("Input Options")
uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
question = st.text_area("Enter your research question")
# Generate button
generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)
if generate_button and uploaded_file and question:
try:
# Initialize summarizer if not already done
if st.session_state['summarizer'] is None:
st.session_state['summarizer'] = Summarizer()
with st.spinner("Generating summaries..."):
summaries, overall_summary = st.session_state['summarizer'].process_excel(uploaded_file, question)
st.session_state['summaries'] = summaries
st.session_state['overall_summary'] = overall_summary
st.success("Summaries generated successfully!")
except Exception as e:
st.error(f"An error occurred: {str(e)}")
elif generate_button:
st.warning("Please upload a file and enter a research question.")
# Main content area
if st.session_state['overall_summary']:
st.header("Overall Summary")
st.write(st.session_state['overall_summary'])
if st.session_state['summaries']:
st.header("Individual Paper Summaries")
# Sorting options
col1, col2 = st.columns([2, 3])
with col1:
sort_by = st.selectbox(
"Sort by",
["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
index=0
)
# Sort summaries based on selection
summaries = st.session_state['summaries']
    if sort_by == "Year":
        # Treat missing years as 0 so numeric years and '' are never compared
        summaries.sort(key=lambda x: x['year'] or 0, reverse=True)
elif sort_by == "Citations":
summaries.sort(key=lambda x: x['times_cited'], reverse=True)
elif sort_by == "Source":
summaries.sort(key=lambda x: x['source'])
elif sort_by == "Type":
summaries.sort(key=lambda x: x['document_type'])
elif sort_by == "Access":
summaries.sort(key=lambda x: x['open_access'])
elif sort_by == "Research Areas":
summaries.sort(key=lambda x: x['research_areas'])
# Display summaries in expandable sections
for paper in summaries:
with st.expander(f"{paper['title']} ({paper['year']})"):
col1, col2 = st.columns([2, 1])
with col1:
st.write("**Summary:**")
st.write(paper['summary'])
with col2:
st.write(f"**Authors:** {paper['authors']}")
st.write(f"**Source:** {paper['source']}")
st.write(f"**DOI:** {paper['doi']}")
st.write(f"**Document Type:** {paper['document_type']}")
st.write(f"**Times Cited:** {paper['times_cited']}")
st.write(f"**Open Access:** {paper['open_access']}")
st.write(f"**Research Areas:** {paper['research_areas']}") |