|
import os
import re
import tempfile

import gradio as gr
import pandas as pd
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    pipeline,
)
|
|
|
|
|
# Checkpoint identifiers, named once so tokenizer and model always agree.
BART_CHECKPOINT = "facebook/bart-large-cnn"
T5_CHECKPOINT = "t5-small"

# Tokenizers and models are created once, at module import, and shared by
# every request handled afterwards.
bart_tokenizer = BartTokenizer.from_pretrained(BART_CHECKPOINT)
bart_model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT)
t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)

# BART-backed summarization pipeline used for the domain-specific summaries.
summarizer = pipeline("summarization", model=bart_model, tokenizer=bart_tokenizer)
|
|
|
|
|
# Terms whose presence marks a text as healthcare-related.
healthcare_keywords = [
    "disease",
    "cancer",
    "patient",
    "treatment",
    "health",
    "illness",
    "medicine",
    "symptom",
    "diagnosis",
    "epidemic",
    "infection",
]

# Terms whose presence marks a text as AI-related.
ai_keywords = [
    "algorithm",
    "artificial intelligence",
    "machine learning",
    "neural network",
    "AI",
    "model",
    "deep learning",
    "prediction",
    "data",
]
|
|
|
|
|
def _mentions_keyword(text, keywords):
    """Return True when *text* contains at least one of *keywords*.

    Matching is case-insensitive. Keywords longer than three characters match
    as substrings, so "patient" also hits "outpatients". Keywords of three
    characters or fewer (acronyms such as "AI") must match as whole words;
    otherwise "ai" would be found inside "pain", "contain" or "training".
    """
    lowered = text.lower()
    for keyword in keywords:
        needle = keyword.lower()
        if len(needle) <= 3:
            if re.search(rf"\b{re.escape(needle)}\b", lowered):
                return True
        elif needle in lowered:
            return True
    return False


def classify_domain(title, abstract):
    """Classify a paper by the keywords found in its title and abstract.

    Args:
        title: Paper title. Non-string values (None, NaN) are ignored.
        abstract: Paper abstract. Non-string values (None, NaN) are ignored.

    Returns:
        "Healthcare, AI" when both keyword lists match, "HealthCare" or "AI"
        when only one does, and "General" otherwise. The labels are returned
        exactly as before because other code compares against them.
    """
    # Join with a space so a keyword cannot be formed across the boundary
    # between the end of the title and the start of the abstract.
    text = " ".join(part for part in (title, abstract) if isinstance(part, str))

    healthcare_detected = _mentions_keyword(text, healthcare_keywords)
    ai_detected = _mentions_keyword(text, ai_keywords)

    if healthcare_detected and ai_detected:
        return "Healthcare, AI"
    if healthcare_detected:
        return "HealthCare"
    if ai_detected:
        return "AI"
    return "General"
|
|
|
|
|
def extractive_summary(text):
    """Summarise *text* with the module-level ``summarizer`` pipeline.

    Args:
        text: The passage to summarise.

    Returns:
        The ``summary_text`` of the first pipeline result, or an empty string
        when *text* is missing or blank (nothing to summarise).
    """
    if not isinstance(text, str) or not text.strip():
        return ""

    # Scale the length bounds to the input. The word count is a cheap stand-in
    # for the token count; long inputs keep the original 150/50 bounds, while
    # short inputs are no longer forced to produce at least 50 tokens.
    word_count = len(text.split())
    max_length = max(10, min(150, word_count))
    min_length = min(50, max_length // 2)

    summary = summarizer(
        text,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        # Clip inputs that exceed the model's maximum input length.
        truncation=True,
    )
    return summary[0]['summary_text']
|
|
|
|
|
def healthcare_agent(abstract):
    """Summarise the healthcare-related sentences of an abstract.

    Args:
        abstract: Abstract text. Non-string values (None, NaN) are treated as
            having no healthcare content.

    Returns:
        The summary produced by ``extractive_summary`` for the sentences that
        mention a healthcare keyword, or "Not related to Healthcare" when no
        sentence does.
    """
    not_related = "Not related to Healthcare"
    if not isinstance(abstract, str):
        return not_related

    keywords = [keyword.lower() for keyword in healthcare_keywords]

    # Split after sentence-ending punctuation that is followed by whitespace.
    # Unlike split('.'), this keeps decimals such as "0.5" intact and leaves
    # each sentence its terminator, so the summariser receives real sentences.
    sentences = re.split(r"(?<=[.!?])\s+", abstract.strip())

    relevant = [
        sentence
        for sentence in sentences
        if any(keyword in sentence.lower() for keyword in keywords)
    ]
    if not relevant:
        return not_related
    return extractive_summary(" ".join(relevant))
|
|
|
|
|
def ai_agent(abstract):
    """Summarise the AI-related sentences of an abstract.

    Args:
        abstract: Abstract text. Non-string values (None, NaN) are treated as
            having no AI content.

    Returns:
        The summary produced by ``extractive_summary`` for the sentences that
        mention an AI keyword, or "Not related to AI" when no sentence does.
    """
    not_related = "Not related to AI"
    if not isinstance(abstract, str):
        return not_related

    def mentions_ai(sentence):
        # Both sides are lower-cased, so the upper-case keyword "AI" can match
        # at all. Keywords of three characters or fewer must match as whole
        # words so "ai" is not found inside "pain" or "training".
        lowered = sentence.lower()
        for keyword in ai_keywords:
            needle = keyword.lower()
            if len(needle) <= 3:
                if re.search(rf"\b{re.escape(needle)}\b", lowered):
                    return True
            elif needle in lowered:
                return True
        return False

    # Split after sentence-ending punctuation followed by whitespace so that
    # decimals stay intact and every sentence keeps its terminator.
    sentences = re.split(r"(?<=[.!?])\s+", abstract.strip())

    ai_relevant_text = " ".join(s for s in sentences if mentions_ai(s)).strip()
    if not ai_relevant_text:
        return not_related
    return extractive_summary(ai_relevant_text)
|
|
|
|
|
def generate_general_summary(abstract):
    """Produce a general-purpose summary of an abstract with the T5 model.

    Args:
        abstract: Abstract text.

    Returns:
        The decoded T5 summary, or an empty string when *abstract* is missing
        or blank. Without this guard a NaN cell would be formatted into the
        prompt as the literal text "nan" and summarised.
    """
    if not isinstance(abstract, str) or not abstract.strip():
        return ""

    # The prompt carries the "summarize: " task prefix.
    input_text = f"summarize: {abstract}"
    input_ids = t5_tokenizer.encode(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    summary_ids = t5_model.generate(
        input_ids,
        max_length=150,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )
    return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
def generate_collaborative_insights(abstract, title, domain):
    """Build the four summaries reported for one paper.

    Args:
        abstract: Abstract text to summarise.
        title: Paper title. Accepted for interface compatibility; not used.
        domain: Label from ``classify_domain``; only "General" and
            "Healthcare, AI" change the outcome here.

    Returns:
        A tuple ``(general_summary, healthcare_summary, ai_summary,
        collaborative_summary)``.
    """
    general_summary = generate_general_summary(abstract)

    # For a "General" paper the domain summaries are fixed placeholders, so
    # return before running the two domain agents whose output would only be
    # thrown away.
    if domain == "General":
        return (
            general_summary,
            "Not related to Healthcare",
            "Not related to AI",
            "Not related to both Healthcare and AI",
        )

    healthcare_summary = healthcare_agent(abstract)
    ai_summary = ai_agent(abstract)

    if domain == "Healthcare, AI":
        collaborative_summary = f"Collaborative Insights between Healthcare and AI: {healthcare_summary} {ai_summary}"
    else:
        collaborative_summary = "Not related to both Healthcare and AI"

    return general_summary, healthcare_summary, ai_summary, collaborative_summary
|
|
|
|
|
def process_single_abstract(title, abstract):
    """Classify one paper and return its four summaries.

    Args:
        title: Paper title; a missing value is treated as an empty string.
        abstract: Paper abstract; a missing value is treated as an empty
            string.

    Returns:
        The tuple ``(general_summary, healthcare_summary, ai_summary,
        collaborative_summary)`` from ``generate_collaborative_insights``.
    """
    # Normalise missing values so the string operations downstream do not fail.
    title = title or ""
    abstract = abstract or ""

    domain = classify_domain(title, abstract)
    return generate_collaborative_insights(abstract, title, domain)
|
|
|
|
|
def process_csv(file):
    """Summarise every paper listed in an uploaded CSV file.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing the path as ``.name``. ``None`` means nothing was
            uploaded.

    Returns:
        The path of the generated ``processed_results.csv``; ``None`` when no
        file was supplied; or an error message string when the CSV lacks the
        'Title' or 'Abstract' column.
        NOTE(review): the error message is returned as a plain string, as
        before. If the caller routes this value to a file-download component
        the message will not be displayed as text - confirm the desired
        behaviour.
    """
    if file is None:
        return None

    # A path is used as-is; anything else is expected to carry it in ``.name``.
    csv_path = file if isinstance(file, (str, os.PathLike)) else file.name
    df = pd.read_csv(csv_path)

    if 'Title' not in df.columns or 'Abstract' not in df.columns:
        return "CSV file must contain 'Title' and 'Abstract' columns."

    output_columns = [
        'Title',
        'Abstract',
        'Domain',
        'General Summary',
        'HealthCare Summary',
        'AI Summary',
        'Collaborative Summary',
    ]

    # Empty cells are read as NaN (a float); convert them to empty strings so
    # the text-processing helpers always receive ``str`` values.
    titles = df['Title'].fillna("").astype(str)
    abstracts = df['Abstract'].fillna("").astype(str)

    results = []
    for title, abstract in zip(titles, abstracts):
        domain = classify_domain(title, abstract)
        general_summary, healthcare_summary, ai_summary, collaborative_summary = (
            generate_collaborative_insights(abstract, title, domain)
        )
        results.append({
            'Title': title,
            'Abstract': abstract,
            'Domain': domain,
            'General Summary': general_summary,
            'HealthCare Summary': healthcare_summary,
            'AI Summary': ai_summary,
            'Collaborative Summary': collaborative_summary,
        })

    # Passing the columns explicitly keeps the header row even when the input
    # has no data rows.
    result_df = pd.DataFrame(results, columns=output_columns)

    # Write into a fresh temporary directory so concurrent requests cannot
    # overwrite each other's results; the file name itself is unchanged.
    output_dir = tempfile.mkdtemp(prefix="summaries_")
    output_file = os.path.join(output_dir, "processed_results.csv")
    result_df.to_csv(output_file, index=False)
    return output_file
|
|
|
|
|
def create_ui():
    """Build the two-tab Gradio interface and launch it.

    The "Single Abstract" tab sends a title and abstract to
    ``process_single_abstract``; the "CSV Upload" tab sends an uploaded file
    to ``process_csv`` and offers the result for download.
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Research Paper Summarization App")

        with gr.Tab("Single Abstract"):
            title_input = gr.Textbox(label="Paper Title", placeholder="Enter the paper title here")
            abstract_input = gr.Textbox(label="Paper Abstract", placeholder="Enter the paper abstract here", lines=5)

            # process_single_abstract returns four values, so four output
            # components are wired up, one per summary, in return order.
            general_output = gr.Textbox(label="General Summary", lines=5)
            healthcare_output = gr.Textbox(label="Healthcare Summary", lines=5)
            ai_output = gr.Textbox(label="AI Summary", lines=5)
            collaborative_output = gr.Textbox(label="Collaborative Summary", lines=5)

            submit_btn_single = gr.Button("Process Abstract")
            submit_btn_single.click(
                process_single_abstract,
                inputs=[title_input, abstract_input],
                outputs=[general_output, healthcare_output, ai_output, collaborative_output],
            )

        with gr.Tab("CSV Upload"):
            file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
            output_file = gr.File(label="Download Processed Results")

            submit_btn_csv = gr.Button("Process CSV")
            submit_btn_csv.click(process_csv, inputs=file_input, outputs=output_file)

    demo.launch()
|
|
|
|
|
# Build and launch the interface as soon as this module is executed.
create_ui()
|
|