pendar02 commited on
Commit
0da1e60
1 Parent(s): f791a84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -47
app.py CHANGED
@@ -1,73 +1,193 @@
1
- # At the top of app.py, add debug printing
 
 
 
 
2
  import os
3
- import logging
4
 
5
- # Configure logging
6
- logging.basicConfig(level=logging.INFO)
7
- logger = logging.getLogger(__name__)
8
 
9
  class Summarizer:
10
  def __init__(self):
11
  try:
12
- # Print current directory contents for debugging
13
- st.write("Current directory contents:")
14
- st.write(os.listdir('.'))
15
 
16
- # Base model
17
- self.base_model = AutoModelForSeq2SeqLM.from_pretrained(
18
- "GanjinZero/biobart-base",
19
- local_files_only=False # Allow downloading base model
20
- )
21
- self.tokenizer = AutoTokenizer.from_pretrained(
22
- "GanjinZero/biobart-base",
23
- local_files_only=False
24
- )
25
 
26
- # Load adapter config from local file
27
- adapter_config_path = "./adapter_config.json"
28
- if not os.path.exists(adapter_config_path):
29
- st.error(f"adapter_config.json not found in {os.getcwd()}")
30
- raise FileNotFoundError("adapter_config.json not found")
31
-
32
- st.write(f"Loading adapter config from {adapter_config_path}")
33
 
34
- # Create LoRA config
 
 
 
35
  lora_config = LoraConfig(
36
  r=8,
37
  lora_alpha=16,
38
- lora_dropout=0.1,
39
- bias="none",
40
- task_type="SEQ_2_SEQ_LM",
41
  target_modules=["q_proj", "v_proj"],
42
  inference_mode=True
43
  )
44
 
45
- # Load base model for fine-tuning
46
- base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained(
47
- "GanjinZero/biobart-base",
48
- local_files_only=False
49
- )
50
-
51
- st.write("Loading fine-tuned model...")
52
- # Try to load the PEFT model from the current directory
53
  self.finetuned_model = PeftModel.from_pretrained(
54
  base_model_for_finetuned,
55
- ".", # Current directory
56
  config=lora_config,
57
- torch_dtype=torch.float32,
58
- is_trainable=False,
59
- local_files_only=True
60
  )
61
-
62
  self.finetuned_model.eval()
 
63
  st.success("Models loaded successfully!")
64
 
65
  except Exception as e:
66
  st.error(f"Error loading models: {str(e)}")
67
  st.write("Debug info:")
68
- st.write(f"Current working directory: {os.getcwd()}")
69
- st.write(f"Directory contents: {os.listdir('.')}")
70
- if os.path.exists('adapter_config.json'):
71
- with open('adapter_config.json', 'r') as f:
72
- st.write("adapter_config.json contents:", f.read())
73
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ from peft import PeftModel, LoraConfig
5
+ import torch
6
  import os
 
7
 
8
+ # Configure logging and page
9
+ st.set_page_config(page_title="Research Paper Summarizer", layout="wide")
 
10
 
11
  class Summarizer:
12
  def __init__(self):
13
  try:
14
+ st.info("Loading models... Please wait.")
 
 
15
 
16
+ # Load base model and tokenizer
17
+ self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
18
+ self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
 
 
 
 
 
 
19
 
20
+ # Debug info
21
+ st.write("Current directory contents:", os.listdir())
 
 
 
 
 
22
 
23
+ # Load fine-tuned model
24
+ base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
25
+
26
+ # Configure LoRA
27
  lora_config = LoraConfig(
28
  r=8,
29
  lora_alpha=16,
 
 
 
30
  target_modules=["q_proj", "v_proj"],
31
  inference_mode=True
32
  )
33
 
34
+ # Load PEFT model
 
 
 
 
 
 
 
35
  self.finetuned_model = PeftModel.from_pretrained(
36
  base_model_for_finetuned,
37
+ ".",
38
  config=lora_config,
39
+ is_trainable=False
 
 
40
  )
 
41
  self.finetuned_model.eval()
42
+
43
  st.success("Models loaded successfully!")
44
 
45
  except Exception as e:
46
  st.error(f"Error loading models: {str(e)}")
47
  st.write("Debug info:")
48
+ st.write(f"Working directory: {os.getcwd()}")
49
+ st.write(f"Files available: {os.listdir()}")
50
+ raise
51
+
52
+ def summarize_text(self, text, max_length=150, use_finetuned=False):
53
+ try:
54
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
55
+
56
+ if use_finetuned:
57
+ summary_ids = self.finetuned_model.generate(
58
+ **inputs,
59
+ max_length=max_length,
60
+ num_beams=4,
61
+ length_penalty=2.0,
62
+ early_stopping=True
63
+ )
64
+ else:
65
+ summary_ids = self.base_model.generate(
66
+ inputs["input_ids"],
67
+ max_length=max_length,
68
+ num_beams=4,
69
+ length_penalty=2.0,
70
+ early_stopping=True
71
+ )
72
+
73
+ return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
74
+ except Exception as e:
75
+ st.error(f"Error in summarization: {str(e)}")
76
+ return "Error generating summary"
77
+
78
+ def process_excel(self, file, question):
79
+ try:
80
+ df = pd.read_excel(file)
81
+ summaries = []
82
+ progress_bar = st.progress(0)
83
+
84
+ total_rows = len(df)
85
+ for idx, row in df.iterrows():
86
+ if pd.notna(row['Abstract']):
87
+ progress_bar.progress((idx + 1) / total_rows)
88
+
89
+ paper_info = {
90
+ 'title': row['Article Title'],
91
+ 'authors': row['Authors'] if pd.notna(row['Authors']) else '',
92
+ 'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
93
+ 'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
94
+ 'doi': row['DOI'] if pd.notna(row['DOI']) else '',
95
+ 'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
96
+ 'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
97
+ 'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
98
+ 'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
99
+ 'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
100
+ }
101
+ summaries.append(paper_info)
102
+
103
+ # Generate overall summary
104
+ combined_summaries = " ".join([s['summary'] for s in summaries])
105
+ overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)
106
+
107
+ return summaries, overall_summary
108
+ except Exception as e:
109
+ st.error(f"Error processing Excel file: {str(e)}")
110
+ return [], "Error generating summary"
111
+
112
+ # Initialize session state
113
+ if 'summarizer' not in st.session_state:
114
+ st.session_state['summarizer'] = None
115
+ if 'summaries' not in st.session_state:
116
+ st.session_state['summaries'] = None
117
+ if 'overall_summary' not in st.session_state:
118
+ st.session_state['overall_summary'] = None
119
+
120
+ # App header
121
+ st.title("Research Paper Summarizer")
122
+ st.write("Upload an Excel file with research papers to generate summaries")
123
+
124
+ # Sidebar inputs
125
+ with st.sidebar:
126
+ st.header("Input Options")
127
+ uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
128
+ question = st.text_area("Enter your research question")
129
+
130
+ generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)
131
+
132
+ # Main processing
133
+ if generate_button and uploaded_file and question:
134
+ try:
135
+ if st.session_state['summarizer'] is None:
136
+ st.session_state['summarizer'] = Summarizer()
137
+
138
+ with st.spinner("Processing papers..."):
139
+ summaries, overall_summary = st.session_state['summarizer'].process_excel(uploaded_file, question)
140
+ st.session_state['summaries'] = summaries
141
+ st.session_state['overall_summary'] = overall_summary
142
+ except Exception as e:
143
+ st.error(f"An error occurred: {str(e)}")
144
+ elif generate_button:
145
+ st.warning("Please upload a file and enter a research question.")
146
+
147
+ # Display results
148
+ if st.session_state['overall_summary']:
149
+ st.header("Overall Summary")
150
+ st.write(st.session_state['overall_summary'])
151
+
152
+ if st.session_state['summaries']:
153
+ st.header("Individual Paper Summaries")
154
+
155
+ # Sorting options
156
+ col1, _ = st.columns([2, 3])
157
+ with col1:
158
+ sort_by = st.selectbox(
159
+ "Sort by",
160
+ ["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
161
+ index=0
162
+ )
163
+
164
+ # Sort summaries
165
+ summaries = st.session_state['summaries']
166
+ if sort_by == "Year":
167
+ summaries.sort(key=lambda x: str(x['year']), reverse=True)
168
+ elif sort_by == "Citations":
169
+ summaries.sort(key=lambda x: x['times_cited'], reverse=True)
170
+ elif sort_by == "Source":
171
+ summaries.sort(key=lambda x: str(x['source']))
172
+ elif sort_by == "Type":
173
+ summaries.sort(key=lambda x: str(x['document_type']))
174
+ elif sort_by == "Access":
175
+ summaries.sort(key=lambda x: str(x['open_access']))
176
+ elif sort_by == "Research Areas":
177
+ summaries.sort(key=lambda x: str(x['research_areas']))
178
+
179
+ # Display summaries
180
+ for paper in summaries:
181
+ with st.expander(f"{paper['title']} ({paper['year']})"):
182
+ col1, col2 = st.columns([2, 1])
183
+ with col1:
184
+ st.write("**Summary:**")
185
+ st.write(paper['summary'])
186
+ with col2:
187
+ st.write(f"**Authors:** {paper['authors']}")
188
+ st.write(f"**Source:** {paper['source']}")
189
+ st.write(f"**DOI:** {paper['doi']}")
190
+ st.write(f"**Document Type:** {paper['document_type']}")
191
+ st.write(f"**Times Cited:** {paper['times_cited']}")
192
+ st.write(f"**Open Access:** {paper['open_access']}")
193
+ st.write(f"**Research Areas:** {paper['research_areas']}")