Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,73 +1,193 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
2 |
import os
|
3 |
-
import logging
|
4 |
|
5 |
-
# Configure logging
|
6 |
-
|
7 |
-
logger = logging.getLogger(__name__)
|
8 |
|
9 |
class Summarizer:
|
10 |
def __init__(self):
|
11 |
try:
|
12 |
-
|
13 |
-
st.write("Current directory contents:")
|
14 |
-
st.write(os.listdir('.'))
|
15 |
|
16 |
-
#
|
17 |
-
self.base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
18 |
-
|
19 |
-
local_files_only=False # Allow downloading base model
|
20 |
-
)
|
21 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
22 |
-
"GanjinZero/biobart-base",
|
23 |
-
local_files_only=False
|
24 |
-
)
|
25 |
|
26 |
-
#
|
27 |
-
|
28 |
-
if not os.path.exists(adapter_config_path):
|
29 |
-
st.error(f"adapter_config.json not found in {os.getcwd()}")
|
30 |
-
raise FileNotFoundError("adapter_config.json not found")
|
31 |
-
|
32 |
-
st.write(f"Loading adapter config from {adapter_config_path}")
|
33 |
|
34 |
-
#
|
|
|
|
|
|
|
35 |
lora_config = LoraConfig(
|
36 |
r=8,
|
37 |
lora_alpha=16,
|
38 |
-
lora_dropout=0.1,
|
39 |
-
bias="none",
|
40 |
-
task_type="SEQ_2_SEQ_LM",
|
41 |
target_modules=["q_proj", "v_proj"],
|
42 |
inference_mode=True
|
43 |
)
|
44 |
|
45 |
-
# Load
|
46 |
-
base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained(
|
47 |
-
"GanjinZero/biobart-base",
|
48 |
-
local_files_only=False
|
49 |
-
)
|
50 |
-
|
51 |
-
st.write("Loading fine-tuned model...")
|
52 |
-
# Try to load the PEFT model from the current directory
|
53 |
self.finetuned_model = PeftModel.from_pretrained(
|
54 |
base_model_for_finetuned,
|
55 |
-
".",
|
56 |
config=lora_config,
|
57 |
-
|
58 |
-
is_trainable=False,
|
59 |
-
local_files_only=True
|
60 |
)
|
61 |
-
|
62 |
self.finetuned_model.eval()
|
|
|
63 |
st.success("Models loaded successfully!")
|
64 |
|
65 |
except Exception as e:
|
66 |
st.error(f"Error loading models: {str(e)}")
|
67 |
st.write("Debug info:")
|
68 |
-
st.write(f"
|
69 |
-
st.write(f"
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
4 |
+
from peft import PeftModel, LoraConfig
|
5 |
+
import torch
|
6 |
import os
|
|
|
7 |
|
8 |
+
# Configure logging and page
|
9 |
+
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")
|
|
|
10 |
|
11 |
class Summarizer:
|
12 |
def __init__(self):
|
13 |
try:
|
14 |
+
st.info("Loading models... Please wait.")
|
|
|
|
|
15 |
|
16 |
+
# Load base model and tokenizer
|
17 |
+
self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
|
18 |
+
self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
+
# Debug info
|
21 |
+
st.write("Current directory contents:", os.listdir())
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Load fine-tuned model
|
24 |
+
base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
|
25 |
+
|
26 |
+
# Configure LoRA
|
27 |
lora_config = LoraConfig(
|
28 |
r=8,
|
29 |
lora_alpha=16,
|
|
|
|
|
|
|
30 |
target_modules=["q_proj", "v_proj"],
|
31 |
inference_mode=True
|
32 |
)
|
33 |
|
34 |
+
# Load PEFT model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
self.finetuned_model = PeftModel.from_pretrained(
|
36 |
base_model_for_finetuned,
|
37 |
+
".",
|
38 |
config=lora_config,
|
39 |
+
is_trainable=False
|
|
|
|
|
40 |
)
|
|
|
41 |
self.finetuned_model.eval()
|
42 |
+
|
43 |
st.success("Models loaded successfully!")
|
44 |
|
45 |
except Exception as e:
|
46 |
st.error(f"Error loading models: {str(e)}")
|
47 |
st.write("Debug info:")
|
48 |
+
st.write(f"Working directory: {os.getcwd()}")
|
49 |
+
st.write(f"Files available: {os.listdir()}")
|
50 |
+
raise
|
51 |
+
|
52 |
+
def summarize_text(self, text, max_length=150, use_finetuned=False):
|
53 |
+
try:
|
54 |
+
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
|
55 |
+
|
56 |
+
if use_finetuned:
|
57 |
+
summary_ids = self.finetuned_model.generate(
|
58 |
+
**inputs,
|
59 |
+
max_length=max_length,
|
60 |
+
num_beams=4,
|
61 |
+
length_penalty=2.0,
|
62 |
+
early_stopping=True
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
summary_ids = self.base_model.generate(
|
66 |
+
inputs["input_ids"],
|
67 |
+
max_length=max_length,
|
68 |
+
num_beams=4,
|
69 |
+
length_penalty=2.0,
|
70 |
+
early_stopping=True
|
71 |
+
)
|
72 |
+
|
73 |
+
return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
74 |
+
except Exception as e:
|
75 |
+
st.error(f"Error in summarization: {str(e)}")
|
76 |
+
return "Error generating summary"
|
77 |
+
|
78 |
+
def process_excel(self, file, question):
|
79 |
+
try:
|
80 |
+
df = pd.read_excel(file)
|
81 |
+
summaries = []
|
82 |
+
progress_bar = st.progress(0)
|
83 |
+
|
84 |
+
total_rows = len(df)
|
85 |
+
for idx, row in df.iterrows():
|
86 |
+
if pd.notna(row['Abstract']):
|
87 |
+
progress_bar.progress((idx + 1) / total_rows)
|
88 |
+
|
89 |
+
paper_info = {
|
90 |
+
'title': row['Article Title'],
|
91 |
+
'authors': row['Authors'] if pd.notna(row['Authors']) else '',
|
92 |
+
'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
|
93 |
+
'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
|
94 |
+
'doi': row['DOI'] if pd.notna(row['DOI']) else '',
|
95 |
+
'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
|
96 |
+
'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
|
97 |
+
'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
|
98 |
+
'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
|
99 |
+
'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
|
100 |
+
}
|
101 |
+
summaries.append(paper_info)
|
102 |
+
|
103 |
+
# Generate overall summary
|
104 |
+
combined_summaries = " ".join([s['summary'] for s in summaries])
|
105 |
+
overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)
|
106 |
+
|
107 |
+
return summaries, overall_summary
|
108 |
+
except Exception as e:
|
109 |
+
st.error(f"Error processing Excel file: {str(e)}")
|
110 |
+
return [], "Error generating summary"
|
111 |
+
|
112 |
+
# Initialize session state
|
113 |
+
if 'summarizer' not in st.session_state:
|
114 |
+
st.session_state['summarizer'] = None
|
115 |
+
if 'summaries' not in st.session_state:
|
116 |
+
st.session_state['summaries'] = None
|
117 |
+
if 'overall_summary' not in st.session_state:
|
118 |
+
st.session_state['overall_summary'] = None
|
119 |
+
|
120 |
+
# App header
|
121 |
+
st.title("Research Paper Summarizer")
|
122 |
+
st.write("Upload an Excel file with research papers to generate summaries")
|
123 |
+
|
124 |
+
# Sidebar inputs
|
125 |
+
with st.sidebar:
|
126 |
+
st.header("Input Options")
|
127 |
+
uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
|
128 |
+
question = st.text_area("Enter your research question")
|
129 |
+
|
130 |
+
generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)
|
131 |
+
|
132 |
+
# Main processing
|
133 |
+
if generate_button and uploaded_file and question:
|
134 |
+
try:
|
135 |
+
if st.session_state['summarizer'] is None:
|
136 |
+
st.session_state['summarizer'] = Summarizer()
|
137 |
+
|
138 |
+
with st.spinner("Processing papers..."):
|
139 |
+
summaries, overall_summary = st.session_state['summarizer'].process_excel(uploaded_file, question)
|
140 |
+
st.session_state['summaries'] = summaries
|
141 |
+
st.session_state['overall_summary'] = overall_summary
|
142 |
+
except Exception as e:
|
143 |
+
st.error(f"An error occurred: {str(e)}")
|
144 |
+
elif generate_button:
|
145 |
+
st.warning("Please upload a file and enter a research question.")
|
146 |
+
|
147 |
+
# Display results
|
148 |
+
if st.session_state['overall_summary']:
|
149 |
+
st.header("Overall Summary")
|
150 |
+
st.write(st.session_state['overall_summary'])
|
151 |
+
|
152 |
+
if st.session_state['summaries']:
|
153 |
+
st.header("Individual Paper Summaries")
|
154 |
+
|
155 |
+
# Sorting options
|
156 |
+
col1, _ = st.columns([2, 3])
|
157 |
+
with col1:
|
158 |
+
sort_by = st.selectbox(
|
159 |
+
"Sort by",
|
160 |
+
["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
|
161 |
+
index=0
|
162 |
+
)
|
163 |
+
|
164 |
+
# Sort summaries
|
165 |
+
summaries = st.session_state['summaries']
|
166 |
+
if sort_by == "Year":
|
167 |
+
summaries.sort(key=lambda x: str(x['year']), reverse=True)
|
168 |
+
elif sort_by == "Citations":
|
169 |
+
summaries.sort(key=lambda x: x['times_cited'], reverse=True)
|
170 |
+
elif sort_by == "Source":
|
171 |
+
summaries.sort(key=lambda x: str(x['source']))
|
172 |
+
elif sort_by == "Type":
|
173 |
+
summaries.sort(key=lambda x: str(x['document_type']))
|
174 |
+
elif sort_by == "Access":
|
175 |
+
summaries.sort(key=lambda x: str(x['open_access']))
|
176 |
+
elif sort_by == "Research Areas":
|
177 |
+
summaries.sort(key=lambda x: str(x['research_areas']))
|
178 |
+
|
179 |
+
# Display summaries
|
180 |
+
for paper in summaries:
|
181 |
+
with st.expander(f"{paper['title']} ({paper['year']})"):
|
182 |
+
col1, col2 = st.columns([2, 1])
|
183 |
+
with col1:
|
184 |
+
st.write("**Summary:**")
|
185 |
+
st.write(paper['summary'])
|
186 |
+
with col2:
|
187 |
+
st.write(f"**Authors:** {paper['authors']}")
|
188 |
+
st.write(f"**Source:** {paper['source']}")
|
189 |
+
st.write(f"**DOI:** {paper['doi']}")
|
190 |
+
st.write(f"**Document Type:** {paper['document_type']}")
|
191 |
+
st.write(f"**Times Cited:** {paper['times_cited']}")
|
192 |
+
st.write(f"**Open Access:** {paper['open_access']}")
|
193 |
+
st.write(f"**Research Areas:** {paper['research_areas']}")
|