Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
from Bio import Entrez | |
import requests | |
import os | |
# Secrets are injected through environment variables; each is None when unset.
HF_API = os.environ.get('HF_API')  # bearer token sent to the Hugging Face Inference API
openai_api_key = os.environ.get('OPENAI_API')  # bearer token sent to the OpenAI chat endpoint
PASSWORD = os.environ.get('password')  # value the UI password field is compared against
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
# Local Qwen inference is switched off: the guard is hard-wired to False, so
# nothing below runs and `tokenizer` / `model` are never bound by this block.
if False:
    qwen_checkpoint = "Qwen/Qwen-7B-Chat"
    tokenizer = AutoTokenizer.from_pretrained(qwen_checkpoint, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        qwen_checkpoint,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
def generate_summary(prompt):
    """Ask the module-level causal LM for a summary of ``prompt``.

    The text is prefixed with a fixed "Summarize the following text:" line,
    encoded with the module-level ``tokenizer`` and handed to
    ``model.generate`` with ``max_length=512``.

    Args:
        prompt: Text to summarise.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # NOTE(review): `tokenizer` and `model` must exist at module level; confirm
    # they are actually loaded before this function is called.
    full_prompt = f"Summarize the following text:\n{prompt}"
    token_ids = tokenizer.encode(full_prompt, return_tensors="pt")
    generated = model.generate(
        token_ids,
        max_length=512,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # pad with the end-of-sequence token id
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def generate_response(prompt):
    """Run ``prompt`` through the module-level causal LM and return its output.

    Args:
        prompt: Text passed to the tokenizer unchanged.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # NOTE(review): `tokenizer` and `model` must exist at module level; confirm
    # they are actually loaded before this function is called.
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = model.generate(token_ids, max_length=512, num_return_sequences=1)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
def search_pubmed_v2(query, retmax=5, mindate=None, maxdate=None, datetype="pdat"):
    """Search PubMed and return the hits as a DataFrame with positional columns.

    Args:
        query: PubMed search term.
        retmax: Maximum number of records to request.
        mindate: Optional lower date bound; left out of the request when falsy.
        maxdate: Optional upper date bound; left out of the request when falsy.
        datetype: Date field the bounds apply to.

    Returns:
        A DataFrame with one row per article and four unnamed columns in the
        order PMID, authors, title, abstract. Empty when nothing matches.
    """
    Entrez.email = 'your.email@example.com'  # Always set the Entrez.email to tell NCBI who you are
    search_kwargs = {
        "db": "pubmed",
        "term": query,
        "retmax": retmax,
        "sort": 'relevance',
        "datetype": datetype,
    }
    # NOTE(review): the E-utilities docs describe mindate/maxdate as a pair;
    # confirm what happens when a caller supplies only one of them.
    if mindate:
        search_kwargs["mindate"] = mindate
    if maxdate:
        search_kwargs["maxdate"] = maxdate

    handle = Entrez.esearch(**search_kwargs)
    try:
        record = Entrez.read(handle)
    finally:
        # Release the connection even when parsing fails.
        handle.close()

    idlist = record['IdList']
    if not idlist:
        # No hits: there is nothing to fetch, so skip the second round trip.
        return pd.DataFrame([])

    handle = Entrez.efetch(db="pubmed", id=idlist, retmode="xml")
    try:
        articles = Entrez.read(handle)['PubmedArticle']
    finally:
        handle.close()

    abstracts = []
    for article in articles:
        citation = article['MedlineCitation']
        details = citation['Article']
        # Group authors carry 'CollectiveName' instead of 'LastName', so fall
        # back to it rather than raising KeyError.
        authors = ' '.join(
            (author.get('LastName') or author.get('CollectiveName', ''))
            + ' ' + author.get('Initials', '')
            for author in details.get('AuthorList', [])
        )
        abstract_text = details.get('Abstract', {}).get('AbstractText', [])
        if isinstance(abstract_text, list):
            # Join the list elements if abstract is a list
            abstract_text = " ".join(abstract_text)
        abstracts.append(
            (citation['PMID'], authors, details['ArticleTitle'], abstract_text)
        )
    return pd.DataFrame(abstracts)
# Function to search PubMed for articles | |
def search_pubmed(query, retmax=5, mindate=None, maxdate=None, datetype="pdat"):
    """Search PubMed and return the hits as a labelled DataFrame.

    Args:
        query: PubMed search term.
        retmax: Maximum number of records to request.
        mindate: Optional lower date bound; left out of the request when falsy.
        maxdate: Optional upper date bound; left out of the request when falsy.
        datetype: Date field the bounds apply to.

    Returns:
        A DataFrame with the columns ``PMID``, ``Authors``, ``Title`` and
        ``Abstract``, one row per article. When nothing matches, the frame is
        empty but still carries those four columns.
    """
    columns = ['PMID', 'Authors', 'Title', 'Abstract']

    Entrez.email = 'example@example.com'
    search_kwargs = {
        "db": "pubmed",
        "term": query,
        "retmax": retmax,
        "sort": 'relevance',
        "datetype": datetype,
    }
    # NOTE(review): the E-utilities docs describe mindate/maxdate as a pair;
    # confirm what happens when a caller supplies only one of them.
    if mindate:
        search_kwargs["mindate"] = mindate
    if maxdate:
        search_kwargs["maxdate"] = maxdate

    handle = Entrez.esearch(**search_kwargs)
    try:
        record = Entrez.read(handle)
    finally:
        # Release the connection even when parsing fails.
        handle.close()

    idlist = record['IdList']
    if not idlist:
        # No hits: skip the fetch and hand back an empty, correctly shaped table.
        return pd.DataFrame(columns=columns)

    handle = Entrez.efetch(db="pubmed", id=idlist, retmode="xml")
    try:
        articles = Entrez.read(handle)['PubmedArticle']
    finally:
        handle.close()

    article_list = []
    for article in articles:
        citation = article['MedlineCitation']
        details = citation['Article']
        abstract_text = details.get('Abstract', {}).get('AbstractText', [])
        if isinstance(abstract_text, list):
            # Join the list elements if abstract is a list
            abstract_text = " ".join(abstract_text)
        # Group authors carry 'CollectiveName' instead of 'LastName', so fall
        # back to it rather than raising KeyError.
        authors = ' '.join(
            (author.get('LastName') or author.get('CollectiveName', ''))
            + ' ' + author.get('Initials', '')
            for author in details.get('AuthorList', [])
        )
        article_list.append({
            'PMID': str(citation['PMID']),
            'Authors': authors,
            'Title': details['ArticleTitle'],
            'Abstract': abstract_text,
        })
    return pd.DataFrame(article_list, columns=columns)
# Function to format search results for OpenAI summarization | |
def format_results_for_openai(table_data):
    """Flatten a table of articles into one text blob for summarisation.

    Args:
        table_data: DataFrame with ``Title``, ``Authors`` and ``Abstract``
            columns.

    Returns:
        One "Title / Authors / Abstract" stanza per row, joined with newlines.
        An empty table yields an empty string.
    """
    import logging

    if table_data.empty:
        # Covers a frame with no columns at all, where column lookup would fail.
        return ""

    summaries = [
        f"Title: {title}\nAuthors:{authors}\nAbstract: {abstract}\n"
        for title, authors, abstract in zip(
            table_data['Title'], table_data['Authors'], table_data['Abstract']
        )
    ]
    # Diagnostic only: keep it out of stdout unless debug logging is enabled.
    logging.getLogger(__name__).debug("Formatted %d article(s) for summarization", len(summaries))
    return "\n".join(summaries)
def get_summary_from_openai(text_to_summarize, openai_api_key):
    """Request a summary of ``text_to_summarize`` from the OpenAI chat endpoint.

    Args:
        text_to_summarize: Text sent as the user message.
        openai_api_key: API key sent as the bearer token.

    Returns:
        The stripped content of the first choice (an empty string when the
        reply carries no content), or ``None`` when the request fails or the
        endpoint answers with a non-200 status. Failures are printed.
    """
    headers = {
        'Authorization': f'Bearer {openai_api_key}',
        'Content-Type': 'application/json',
    }
    system_prompt = (
        "Please summarize the following PubMed search results,\n"
        "including the authors who conducted the research, the main research subject, and the major findings.\n"
        "Please compare the difference among these articles.\n"
        "Please return your results in a single paragraph in the regular scientific paper fashion for each article:"
    )
    data = {
        "model": "gpt-3.5-turbo",  # Specify the GPT-3.5-turbo model
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text_to_summarize},
        ],
    }
    try:
        # Without a timeout a stalled connection would block the UI forever.
        response = requests.post(
            'https://api.openai.com/v1/chat/completions',
            headers=headers,
            json=data,
            timeout=120,
        )
    except requests.RequestException as exc:
        # Report transport failures the same way as HTTP errors: print, return None.
        print(f"Error: {exc}")
        return None

    if response.status_code != 200:
        # Print the error message if the API call was unsuccessful
        print(f"Error: {response.status_code}")
        print(response.text)
        return None

    # Tolerate a missing or empty `choices` list and a null `content` field.
    choices = response.json().get('choices') or [{}]
    message = choices[0].get('message') or {}
    return (message.get('content') or '').strip()
# Function that combines PubMed search with OpenAI summarization | |
def summarize_pubmed_search(search_results):
    """Format a table of PubMed hits and summarise it through OpenAI.

    Args:
        search_results: Table handed to ``format_results_for_openai``.

    Returns:
        Whatever ``get_summary_from_openai`` returns for the formatted text,
        using the module-level ``openai_api_key``.
    """
    return get_summary_from_openai(
        format_results_for_openai(search_results),
        openai_api_key,
    )
# Function to summarize articles using Hugging Face's API | |
def summarize_with_huggingface(model, selected_articles, password):
    """Summarise the selected articles via OpenAI when unlocked, else Hugging Face.

    Args:
        model: Hugging Face model id appended to the Inference API URL.
        selected_articles: DataFrame with ``PMID``, ``Authors``, ``Title`` and
            ``Abstract`` columns.
        password: Value supplied by the user. When it equals the configured
            ``PASSWORD`` the OpenAI path (``summarize_pubmed_search``) is used.

    Returns:
        The summary text, or an empty string when no articles were supplied.

    Raises:
        requests.HTTPError: The Inference API answered with an error status.
        ValueError: The Inference API reply held no generated or summary text.
    """
    if selected_articles is None or selected_articles.empty:
        # Nothing to summarise: do not spend an API call on an empty prompt.
        return ""

    # An unset PASSWORD must never unlock the OpenAI path.
    if PASSWORD is not None and password == PASSWORD:
        return summarize_pubmed_search(selected_articles)

    # Prepare the text to summarize: concatenate all abstracts
    text_to_summarize = " ".join(
        f"PMID: {article['PMID']}. Authors: {article['Authors']}. Title: {article['Title']}. Abstract: {article['Abstract']}."
        for article in selected_articles.to_dict(orient='records')
    )

    USE_LOCAL = False
    if USE_LOCAL:
        # generate_response already returns decoded text, not an HTTP response.
        return generate_response(text_to_summarize)

    api_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {HF_API}"}
    payload = {
        "inputs": text_to_summarize,
        "parameters": {"max_length": 300},  # Adjust as needed
    }
    # Without a timeout a stalled connection would block the UI forever.
    response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    response.raise_for_status()  # Raise an HTTPError if the HTTP request returned an unsuccessful status code

    # The API returns a list of dictionaries; the text is in the first one.
    # Text-generation models use 'generated_text', summarization models
    # use 'summary_text'.
    result = response.json()
    first = result[0] if isinstance(result, list) and result else result
    if isinstance(first, dict):
        for key in ('generated_text', 'summary_text'):
            if key in first:
                return first[key]
    raise ValueError(f"Unexpected response from Hugging Face Inference API: {result!r}")
import gradio as gr | |
from Bio import Entrez | |
# Always tell NCBI who you are | |
Entrez.email = "your.email@example.com" | |
def process_query(keywords, top_k):
    """Search PubMed and return the hits as plain rows for a table widget.

    Args:
        keywords: Search term passed to ``search_pubmed``.
        top_k: Maximum number of articles to request.

    Returns:
        A list of ``[PMID, Authors, Title, Abstract]`` lists, one per article;
        an empty list when the search yields nothing.
    """
    articles = search_pubmed(keywords, top_k)
    if articles.empty:
        return []
    # Iterating a DataFrame directly yields column labels, so go through row
    # records, and use the column names search_pubmed actually produces.
    return [
        [article['PMID'], article['Authors'], article['Title'], article['Abstract']]
        for article in articles.to_dict(orient='records')
    ]
def summarize_articles(indices, articles_for_display,
                       model="h2oai/h2ogpt-4096-llama2-7b-chat", password=""):
    """Summarise the rows of ``articles_for_display`` named by ``indices``.

    Args:
        indices: Comma-separated row positions, e.g. ``"0, 2"``. Entries that
            are not non-negative integers, or that fall outside the table,
            are ignored.
        articles_for_display: DataFrame of articles, one per row.
        model: Hugging Face model id forwarded to ``summarize_with_huggingface``.
        password: Password forwarded to ``summarize_with_huggingface``.

    Returns:
        The summary produced by ``summarize_with_huggingface``.
    """
    # isdecimal() accepts exactly the strings int() can parse; isdigit() also
    # admits characters such as superscripts that int() rejects.
    tokens = (part.strip() for part in indices.split(','))
    selected_indices = [int(token) for token in tokens if token.isdecimal()]

    # Drop positions beyond the table instead of failing with IndexError.
    row_count = len(articles_for_display)
    in_range = [index for index in selected_indices if index < row_count]

    # summarize_with_huggingface expects a DataFrame plus model and password,
    # so select rows with iloc rather than building a list of dicts.
    selected_articles = articles_for_display.iloc[in_range]
    return summarize_with_huggingface(model, selected_articles, password)
def check_password(username, password):
    """Check a username/password pair against the configured credentials.

    Args:
        username: Username supplied by the user.
        password: Password supplied by the user.

    Returns:
        ``(True, "Welcome!")`` when both values match the module-level
        ``USERNAME`` and ``PASSWORD``; otherwise
        ``(False, "Incorrect username or password.")``. Authentication is
        also refused when either credential is not configured.
    """
    import hmac

    failure = (False, "Incorrect username or password.")

    # USERNAME may not be defined at module level; look it up defensively so a
    # missing setting denies access instead of raising NameError.
    expected_user = globals().get("USERNAME")
    if expected_user is None or PASSWORD is None:
        return failure
    if username is None or password is None:
        return failure

    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    user_ok = hmac.compare_digest(
        str(username).encode("utf-8"), str(expected_user).encode("utf-8")
    )
    password_ok = hmac.compare_digest(
        str(password).encode("utf-8"), str(PASSWORD).encode("utf-8")
    )
    if user_ok and password_ok:
        return True, "Welcome!"
    return failure
# Gradio interface | |
# Main UI: search PubMed into a table, then summarise the table's contents.
# NOTE(review): the source's indentation was lost, so which widgets sit inside
# each gr.Row() is reconstructed — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("### PubMed Article Summarizer")
    with gr.Row():
        # Mask the field so the secret is not shown in clear text on screen.
        password_input = gr.Textbox(label="Enter the password", type="password")
        model_input = gr.Textbox(label="Enter the model to use", value="h2oai/h2ogpt-4096-llama2-7b-chat")
    with gr.Row():
        startdate = gr.Textbox(label="Starting year")
        enddate = gr.Textbox(label="End year")
    query_input = gr.Textbox(label="Query Keywords")
    retmax_input = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of articles")
    search_button = gr.Button("Search")
    output_table = gr.Dataframe(headers=["PMID", "Authors", "Title", "Abstract"])
    summarize_button = gr.Button("Summarize")
    summary_output = gr.Textbox()

    def update_output_table(query, retmax, startdate, enddate):
        """Run the PubMed search and return the result table for display."""
        return search_pubmed(query, retmax, startdate, enddate)

    search_button.click(
        update_output_table,
        inputs=[query_input, retmax_input, startdate, enddate],
        outputs=output_table,
    )
    summarize_button.click(
        fn=summarize_with_huggingface,
        inputs=[model_input, output_table, password_input],
        outputs=summary_output,
    )

demo.launch(debug=True)
# Disabled earlier UI (index-based article selection); the guard is hard-wired
# to False, so nothing below runs.
# NOTE(review): the source's indentation was lost, so the widget nesting below
# is reconstructed — confirm before re-enabling.
if False:
    with gr.Blocks() as demo:
        gr.Markdown("### PubMed Article Summarizer")
        with gr.Row():
            query_input = gr.Textbox(label="Query Keywords")
            top_k_input = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K Results")
        search_button = gr.Button("Search")
        output_table = gr.Dataframe(headers=["Title", "Authors", "Abstract", "PMID"])
        indices_input = gr.Textbox(label="Enter indices of articles to summarize (comma-separated)")
        summarize_button = gr.Button("Summarize Selected Articles")
        summary_output = gr.Textbox(label="Summary")

        # Search fills the table; Summarize reads the table plus chosen indices.
        search_button.click(fn=process_query, inputs=[query_input, top_k_input], outputs=output_table)
        summarize_button.click(fn=summarize_articles, inputs=[indices_input, output_table], outputs=summary_output)

    # NOTE(review): credentials are hard-coded here; move them to configuration
    # before this branch is ever enabled.
    demo.launch(auth=("user", "pass1234"), debug=True)