import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
from bs4 import BeautifulSoup
import time
import xml.etree.ElementTree as ET
# Load BioMedLM and move it to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("stanford-crfm/BioMedLM", model_max_length=1024)
model = AutoModelForCausalLM.from_pretrained("stanford-crfm/BioMedLM").to(device)

# NCBI E-utilities API key and number of PubMed results to retrieve per search
api_key = '2c78468d6246082d456a140bb1de415ed108'
num_results = 10
def extract_longer_answers_from_paragraphs(paragraphs, query, tokenizer, model):
    question = f" What is the mechanism of {query}?"
    top_p = 0.9          # Adjust as needed
    max_new_tokens = 50  # Adjust as needed
    # Truncate the concatenated abstracts so abstracts + question + generated answer
    # all fit within BioMedLM's 1024-token context window
    question_ids = tokenizer(question, add_special_tokens=False)["input_ids"]
    context_budget = 1024 - len(question_ids) - max_new_tokens
    context_ids = tokenizer(" ".join(paragraphs), add_special_tokens=False,
                            truncation=True, max_length=context_budget)["input_ids"]
    input_ids = torch.tensor([context_ids + question_ids], device=device)
    outputs = model.generate(
        input_ids,
        do_sample=True,           # top_p only takes effect when sampling
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        num_beams=1,              # Adjust as needed
        no_repeat_ngram_size=2    # Adjust as needed
    )
    # Decode only the newly generated tokens, not the prompt
    answer = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    return answer
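# Quick local sanity check (hypothetical example, left commented out so the Space
# still launches cleanly): feed a hand-written abstract in and inspect the generation.
# sample_abstracts = ["Sertraline is a selective serotonin reuptake inhibitor (SSRI)."]
# print(extract_longer_answers_from_paragraphs(sample_abstracts, "sertraline", tokenizer, model))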
def retrieve_and_answer(query1, query2):
    # Combine both queries into a single PubMed search term
    combined_query = f"({query1}) AND ({query2})"
    answer = fetch_and_generate(query1, combined_query, tokenizer, model)
    return answer
def fetch_and_generate(query, combined_query, tokenizer, model):
    # Search PubMed for the combined query via the E-utilities esearch endpoint
    esearch_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=pubmed&api_key={api_key}&term={combined_query}&retmax={num_results}&sort=relevance"
    response = requests.get(esearch_url)
    if response.status_code == 200:
        root = ET.fromstring(response.text)
        paragraphs = []
        for article_id in root.find('IdList'):
            article_id = article_id.text
            # Fetch the full record for each PMID and pull out its abstract
            efetch_url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&api_key={api_key}&id={article_id}&retmode=xml"
            response = requests.get(efetch_url)
            if response.status_code == 200:
                soup = BeautifulSoup(response.text, 'xml')
                for article in soup.find_all('PubmedArticle'):
                    abstract = article.find('AbstractText')
                    if abstract:
                        paragraphs.append(abstract.text)
            else:
                print("Error:", response.status_code)
            time.sleep(3)  # Stay under NCBI's rate limits between efetch calls
        answer = extract_longer_answers_from_paragraphs(paragraphs, query, tokenizer, model)
        return answer
    else:
        print("Error:", response.status_code)
        return "Error fetching articles."
# Gradio Interface
iface = gr.Interface(
    fn=retrieve_and_answer,
    inputs=[
        gr.Textbox(placeholder="Enter Query 1", label="Query 1"),
        gr.Textbox(placeholder="Enter Query 2", label="Query 2"),
    ],
    outputs=gr.Textbox(label="Answer from BioMedLM"),
    title="PubMed Question Answering: Stanford/BioMedLM",
    description="Enter two queries to retrieve PubMed articles and generate an answer with BioMedLM",
    examples=[
        ["sertraline", "mechanism"],
        ["cancer", "treatment"],
    ],
)
iface.launch()