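"""Research Paper Summarization App.

A Gradio app that classifies a paper's domain (Healthcare, AI, both, or
General) via keyword matching, then produces a general paraphrased summary
with T5, domain-focused summaries with BART, and a combined "collaborative"
summary when both domains are detected. Abstracts can be entered one at a
time or uploaded in bulk as a CSV with 'Title' and 'Abstract' columns.
"""
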
import gradio as gr
import pandas as pd
from transformers import BartTokenizer, BartForConditionalGeneration, pipeline, T5Tokenizer, T5ForConditionalGeneration

# Initialize BART and T5 models and tokenizers for summarization
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
t5_tokenizer = T5Tokenizer.from_pretrained('t5-small')
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

# Initialize Summarization pipeline for BART
summarizer = pipeline("summarization", model=bart_model, tokenizer=bart_tokenizer)
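# Note (optional, not part of the original setup): everything above runs on the
# CPU by default. If a CUDA GPU is available, the pipeline accepts a device
# index, e.g. pipeline("summarization", model=bart_model,
# tokenizer=bart_tokenizer, device=0).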

# Healthcare and AI keyword lists
healthcare_keywords = ["disease", "cancer", "patient", "treatment", "health", "illness", "medicine", "symptom", "diagnosis", "epidemic", "infection"]
ai_keywords = ["algorithm", "artificial intelligence", "machine learning", "neural network", "AI", "model", "deep learning", "prediction", "data"]

# Function to classify the domain (Healthcare, AI, or both)
def classify_domain(title, abstract):
    # Join with a space so a keyword is never split or spuriously formed at the title/abstract boundary
    text = f"{title} {abstract}".lower()
    healthcare_detected = any(keyword.lower() in text for keyword in healthcare_keywords)
    ai_detected = any(keyword.lower() in text for keyword in ai_keywords)

    if healthcare_detected and ai_detected:
        return "Healthcare, AI"  # Both healthcare and AI
    elif healthcare_detected:
        return "HealthCare"
    elif ai_detected:
        return "AI"
    return "General"

# Function to generate summaries using BART (note: BART summarization is abstractive rather than extractive)
def extractive_summary(text):
    summary = summarizer(text, max_length=150, min_length=50, do_sample=False)
    return summary[0]['summary_text']
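
# Note (assumption, not in the original code): BART can attend to at most
# ~1024 tokens, so very long inputs can fail. A minimal safeguard is to let
# the pipeline truncate the input, e.g.:
#
#     summary = summarizer(text, max_length=150, min_length=50,
#                          do_sample=False, truncation=True)
#
# or to split `text` into chunks and summarize each chunk separately.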

# Healthcare Agent to enhance healthcare-related content (focusing on diseases and treatments)
def healthcare_agent(abstract):
    # Check if healthcare-related keywords are in the abstract
    healthcare_relevant_text = " ".join(
        sentence for sentence in abstract.split('.')
        if any(keyword.lower() in sentence.lower() for keyword in healthcare_keywords)
    ).strip()

    # If healthcare-related sentences are found, generate a summary
    if healthcare_relevant_text:
        healthcare_summary = extractive_summary(healthcare_relevant_text)
        return healthcare_summary
    else:
        return "Not related to Healthcare"

# AI Agent to enhance AI-related content (focusing on algorithms and machine learning)
def ai_agent(abstract):
    # Check if AI-related keywords are in the abstract
    # Lowercase the keywords as well, otherwise uppercase keywords like "AI" can never match the lowercased sentence
    ai_relevant_text = " ".join(
        sentence for sentence in abstract.split('.')
        if any(keyword.lower() in sentence.lower() for keyword in ai_keywords)
    ).strip()
    if ai_relevant_text:
        ai_summary = extractive_summary(ai_relevant_text)
        return ai_summary
    else:
        return "Not related to AI"

# Function to generate general summary which is paraphrased (using T5 model for rephrasing)
def generate_general_summary(abstract):
    # Use T5 model to paraphrase the abstract and generate a general summary
    input_text = f"summarize: {abstract}"
    input_ids = t5_tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = t5_model.generate(input_ids, max_length=150, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

# Function to generate collaborative insights between healthcare and AI
def generate_collaborative_insights(abstract, title, domain):
    # Initialize summary placeholders
    general_summary = generate_general_summary(abstract)
    healthcare_summary = healthcare_agent(abstract)
    ai_summary = ai_agent(abstract)

    # Collaborative summary if both healthcare and AI are involved
    if domain == "Healthcare, AI":
        collaborative_summary = f"Collaborative Insights between Healthcare and AI: {healthcare_summary} {ai_summary}"
    else:
        collaborative_summary = "Not related to both Healthcare and AI"  # Collaborative insight will not be generated if the domain does not match

    # If domain doesn't match healthcare or AI, use general summary and not related for the respective fields
    if domain == "General":
        healthcare_summary = "Not related to Healthcare"
        ai_summary = "Not related to AI"

    return general_summary, healthcare_summary, ai_summary, collaborative_summary

# Function to process a single abstract
def process_single_abstract(title, abstract):
    domain = classify_domain(title, abstract)
    general_summary, healthcare_summary, ai_summary, collaborative_summary = generate_collaborative_insights(abstract, title, domain)
    return general_summary, healthcare_summary, ai_summary, collaborative_summary
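
# Example (illustrative only; the title and abstract below are invented):
#
#     general, healthcare, ai, collaborative = process_single_abstract(
#         "Deep Learning for Early Cancer Diagnosis",
#         "We train a neural network on patient records to support diagnosis "
#         "and treatment planning.",
#     )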

# Function to process a CSV file
def process_csv(file):
    # gr.File may provide a tempfile-like object (with a .name attribute) or a
    # plain filepath string, depending on the Gradio version; handle both.
    path = file.name if hasattr(file, "name") else file
    df = pd.read_csv(path)

    # Ensure the required columns are present
    if 'Title' not in df.columns or 'Abstract' not in df.columns:
        raise gr.Error("CSV file must contain 'Title' and 'Abstract' columns.")

    # Prepare a list to store results
    results = []

    # Process each row in the CSV file
    for index, row in df.iterrows():
        # Cast to str to guard against missing/NaN values in the CSV
        title = str(row['Title'])
        abstract = str(row['Abstract'])

        # Classify the domain (Healthcare, AI, or both)
        domain = classify_domain(title, abstract)

        # Generate summaries
        general_summary, healthcare_summary, ai_summary, collaborative_summary = generate_collaborative_insights(abstract, title, domain)

        # Store the results
        results.append({
            'Title': title,
            'Abstract': abstract,
            'Domain': domain,
            'General Summary': general_summary,
            'HealthCare Summary': healthcare_summary,
            'AI Summary': ai_summary,
            'Collaborative Summary': collaborative_summary
        })

    # Convert results into DataFrame
    result_df = pd.DataFrame(results)

    # Save to CSV and return path
    output_file = "processed_results.csv"
    result_df.to_csv(output_file, index=False)
    return output_file
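
# Example (sketch; "papers.csv" is a hypothetical file with 'Title' and
# 'Abstract' columns). Because process_csv accepts either a Gradio file object
# or a plain path, it can also be called directly:
#
#     output_path = process_csv("papers.csv")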

# Gradio UI components
def create_ui():
    with gr.Blocks() as demo:
        gr.Markdown("## Research Paper Summarization App")

        with gr.Tab("Single Abstract"):
            title_input = gr.Textbox(label="Paper Title", placeholder="Enter the paper title here")
            abstract_input = gr.Textbox(label="Paper Abstract", placeholder="Enter the paper abstract here", lines=5)
            general_output = gr.Textbox(label="General Summary", lines=3)
            healthcare_output = gr.Textbox(label="Healthcare Summary", lines=3)
            ai_output = gr.Textbox(label="AI Summary", lines=3)
            collaborative_output = gr.Textbox(label="Collaborative Summary", lines=3)

            # Button to process single abstract
            submit_btn_single = gr.Button("Process Abstract")
            # process_single_abstract returns four values, so wire up four output components
            submit_btn_single.click(
                process_single_abstract,
                inputs=[title_input, abstract_input],
                outputs=[general_output, healthcare_output, ai_output, collaborative_output],
            )

        with gr.Tab("CSV Upload"):
            file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
            output_file = gr.File(label="Download Processed Results")

            # Button to process CSV file
            submit_btn_csv = gr.Button("Process CSV")
            submit_btn_csv.click(process_csv, inputs=file_input, outputs=output_file)

    return demo

# Launch the Gradio UI when run as a script
if __name__ == "__main__":
    create_ui().launch()