# File size: 7,666 Bytes
# a964db1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import streamlit as st
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig, LoraConfig
import torch

class Summarizer:
    """Summarizes research-paper abstracts with BioBART.

    Two models are held:
      * ``base_model`` — stock ``GanjinZero/biobart-base`` used for
        per-paper abstract summaries.
      * ``finetuned_model`` — the same base wrapped with LoRA adapter
        weights loaded from the local ``models`` directory, used for the
        combined overall summary.
    """

    # Beam-search settings shared by both models so output style matches.
    _GEN_KWARGS = {"num_beams": 4, "length_penalty": 2.0, "early_stopping": True}

    def __init__(self):
        # Base model/tokenizer for individual summaries.
        self.base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
        self.tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
        self.base_model.eval()  # inference only

        try:
            # LoRA adapter configuration (must match how the adapter was trained).
            lora_config = LoraConfig(
                r=8,
                lora_alpha=16,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_2_SEQ_LM",
                target_modules=["q_proj", "v_proj"],
                inference_mode=True
            )

            # A second base instance so adapter injection doesn't touch self.base_model.
            base_model_for_finetuned = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")

            # Load the PEFT adapter from the local path in the Space.
            self.finetuned_model = PeftModel.from_pretrained(
                base_model_for_finetuned,
                "models",  # Points to the models directory in the Space
                config=lora_config
            )
            self.finetuned_model.eval()

        except Exception as e:
            # Surface the failure in the UI, then re-raise so the caller
            # never holds a half-initialized Summarizer.
            st.error(f"Error loading fine-tuned model: {str(e)}")
            raise

    def summarize_text(self, text, max_length=150, use_finetuned=False):
        """Return a beam-search summary of ``text``.

        Args:
            text: Input passage (truncated to 512 tokens).
            max_length: Maximum summary length in tokens.
            use_finetuned: If True, generate with the LoRA fine-tuned
                model; otherwise use the plain base model.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        model = self.finetuned_model if use_finetuned else self.base_model
        # no_grad: generation needs no autograd graph — saves memory/time.
        with torch.no_grad():
            summary_ids = model.generate(**inputs, max_length=max_length, **self._GEN_KWARGS)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process_excel(self, file, question):
        """Summarize every abstract in an Excel sheet, then the whole set.

        Args:
            file: Excel file-like object with Web-of-Science-style columns
                ('Abstract', 'Article Title', 'Authors', ...).
            question: The user's research question (currently unused by the
                summarization itself; kept for interface stability).

        Returns:
            (summaries, overall_summary): a list of per-paper metadata
            dicts each carrying a 'summary', and one combined summary
            produced by the fine-tuned model ('' if no abstracts found).
        """
        df = pd.read_excel(file)

        summaries = []
        # One progress element, updated in place. (Calling st.progress()
        # repeatedly would stack a new bar per row.)
        progress_bar = st.progress(0.0)
        total = len(df)
        # enumerate gives a true position; df.index labels need not be 0..n-1.
        for pos, (_, row) in enumerate(df.iterrows(), start=1):
            if pd.notna(row['Abstract']):
                paper_info = {
                    'title': row['Article Title'],
                    'authors': row['Authors'] if pd.notna(row['Authors']) else '',
                    'source': row['Source Title'] if pd.notna(row['Source Title']) else '',
                    'year': row['Publication Year'] if pd.notna(row['Publication Year']) else '',
                    'doi': row['DOI'] if pd.notna(row['DOI']) else '',
                    'document_type': row['Document Type'] if pd.notna(row['Document Type']) else '',
                    'times_cited': row['Times Cited, WoS Core'] if pd.notna(row['Times Cited, WoS Core']) else 0,
                    'open_access': row['Open Access Designations'] if pd.notna(row['Open Access Designations']) else '',
                    'research_areas': row['Research Areas'] if pd.notna(row['Research Areas']) else '',
                    'summary': self.summarize_text(row['Abstract'], use_finetuned=False)
                }
                summaries.append(paper_info)
            if total:
                progress_bar.progress(pos / total)

        # Nothing to combine — avoid summarizing an empty string.
        if not summaries:
            return [], ""

        # Overall summary over the concatenated per-paper summaries,
        # using the fine-tuned model.
        combined_summaries = " ".join(s['summary'] for s in summaries)
        overall_summary = self.summarize_text(combined_summaries, max_length=250, use_finetuned=True)

        return summaries, overall_summary

# Page configuration
st.set_page_config(page_title="Research Paper Summarizer", layout="wide")

# Seed every session-state slot on first run so later reads never KeyError.
for _slot in ('summarizer', 'summaries', 'overall_summary'):
    if _slot not in st.session_state:
        st.session_state[_slot] = None

# Header
st.title("Research Paper Summarizer")
st.write("Upload an Excel file with research papers to generate summaries")

# Sidebar: file upload, research question, and the generate action.
with st.sidebar:
    st.header("Input Options")
    uploaded_file = st.file_uploader("Choose an Excel file", type=['xlsx', 'xls'])
    question = st.text_area("Enter your research question")

    generate_button = st.button("Generate Summaries", type="primary", use_container_width=True)

    if generate_button:
        if not (uploaded_file and question):
            # Button pressed without both inputs — prompt the user.
            st.warning("Please upload a file and enter a research question.")
        else:
            try:
                # Lazily build the (expensive) summarizer once per session.
                if st.session_state['summarizer'] is None:
                    st.session_state['summarizer'] = Summarizer()

                with st.spinner("Generating summaries..."):
                    result = st.session_state['summarizer'].process_excel(uploaded_file, question)
                    st.session_state['summaries'], st.session_state['overall_summary'] = result
                    st.success("Summaries generated successfully!")

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

# Main content area: overall summary first, then sortable per-paper cards.
if st.session_state['overall_summary']:
    st.header("Overall Summary")
    st.write(st.session_state['overall_summary'])

if st.session_state['summaries']:
    st.header("Individual Paper Summaries")

    # Sorting options
    col1, col2 = st.columns([2, 3])
    with col1:
        sort_by = st.selectbox(
            "Sort by",
            ["Year", "Citations", "Source", "Type", "Access", "Research Areas"],
            index=0
        )

    def _numeric_key(value):
        """Coerce possibly-missing values ('' default) to a sortable float.

        Missing/non-numeric entries sort last under reverse=True. Without
        this, mixing '' with numeric years raises TypeError in sorted().
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("-inf")

    # option -> (dict field, numeric?, descending?)
    _SORT_SPECS = {
        "Year": ("year", True, True),
        "Citations": ("times_cited", True, True),
        "Source": ("source", False, False),
        "Type": ("document_type", False, False),
        "Access": ("open_access", False, False),
        "Research Areas": ("research_areas", False, False),
    }
    field, numeric, descending = _SORT_SPECS[sort_by]
    key_fn = (lambda p: _numeric_key(p[field])) if numeric else (lambda p: str(p[field]))

    # sorted() keeps the session-state list itself unmodified.
    summaries = sorted(st.session_state['summaries'], key=key_fn, reverse=descending)

    # Display summaries in expandable sections
    for paper in summaries:
        with st.expander(f"{paper['title']} ({paper['year']})"):
            col1, col2 = st.columns([2, 1])
            with col1:
                st.write("**Summary:**")
                st.write(paper['summary'])
            with col2:
                st.write(f"**Authors:** {paper['authors']}")
                st.write(f"**Source:** {paper['source']}")
                st.write(f"**DOI:** {paper['doi']}")
                st.write(f"**Document Type:** {paper['document_type']}")
                st.write(f"**Times Cited:** {paper['times_cited']}")
                st.write(f"**Open Access:** {paper['open_access']}")
                st.write(f"**Research Areas:** {paper['research_areas']}")