|
import os |
|
import json |
|
import logging |
|
from enum import Enum |
|
from pydantic import BaseModel, Field |
|
import pandas as pd |
|
from huggingface_hub import InferenceClient |
|
from tenacity import retry, stop_after_attempt, wait_exponential |
|
|
|
|
|
# --- Logging configuration --------------------------------------------------
# Module-level logger writing INFO+ records both to the console and to
# "hf_api.log" in the current working directory.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Console handler mirrors records to the terminal.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# File handler persists the same records for later inspection.
file_handler = logging.FileHandler("hf_api.log")
file_handler.setLevel(logging.INFO)

# Shared record format: timestamp - logger name - level - message.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Guard against duplicate handlers (and duplicated log lines) if this module
# is imported or reloaded more than once.
if not logger.handlers:
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

# --- Hugging Face credentials & clients -------------------------------------
# The API token must come from the HF_TOKEN environment variable; fail fast at
# import time if it is missing so misconfiguration is caught immediately.
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    logger.error("Hugging Face API token not found. Set the HF_TOKEN environment variable.")
    raise EnvironmentError("HF_TOKEN environment variable is not set.")

# Two models: client1 (Llama) is used for per-abstract scoring/summarization,
# client2 (Qwen) for composing the longer per-topic newsletter sections.
MODEL_NAME1 = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_NAME2 = "Qwen/Qwen2.5-72B-Instruct"
try:
    client1 = InferenceClient(model=MODEL_NAME1, token=HF_TOKEN)
    logger.info(f"InferenceClient for model '{MODEL_NAME1}' instantiated successfully.")
except Exception as e:
    logger.error(f"Failed to instantiate InferenceClient for model '{MODEL_NAME1}': {e}")
    raise

try:
    client2 = InferenceClient(model=MODEL_NAME2, token=HF_TOKEN)
    logger.info(f"InferenceClient for model '{MODEL_NAME2}' instantiated successfully.")
except Exception as e:
    logger.error(f"Failed to instantiate InferenceClient for model '{MODEL_NAME2}': {e}")
    raise
|
|
|
|
|
class EvaluationSchema(BaseModel):
    """Structured relevance verdict produced by the scoring model."""

    # Free-text justification the model writes before committing to a score.
    reasoning: str
    # Relevance for a rheumatologist audience, constrained to the range 0-10.
    relevance_score: int = Field(ge=0, le=10)
|
|
|
class TopicEnum(Enum):
    # Closed set of newsletter topics a paper can be filed under.
    # NOTE(review): summarize_abstract interpolates `TopicEnum.__doc__` into
    # its prompt — do not add a class docstring here without checking that
    # usage, since it would change the runtime prompt text.
    Rheumatoid_Arthritis = "Rheumatoid Arthritis"
    Systemic_Lupus_Erythematosus = "Systemic Lupus Erythematosus"
    Scleroderma = "Scleroderma"
    Sjogren_s_Disease = "Sjogren's Disease"
    Ankylosing_Spondylitis = "Ankylosing Spondylitis"
    Psoriatic_Arthritis = "Psoriatic Arthritis"
    Gout = "Gout"
    Vasculitis = "Vasculitis"
    Osteoarthritis = "Osteoarthritis"
    Infectious_Diseases = "Infectious Diseases"
    Immunology = "Immunology"
    Genetics = "Genetics"
    Biologics = "Biologics"
    Biosimilars = "Biosimilars"
    Small_Molecules = "Small Molecules"
    Clinical_Trials = "Clinical Trials"
    Health_Policy = "Health Policy"
    Patient_Education = "Patient Education"
    # Catch-all bucket; also the default topic in SummarySchema.
    Other_Rheumatic_Diseases = "Other Rheumatic Diseases"
|
|
|
class SummarySchema(BaseModel):
    """One-sentence abstract summary plus a single assigned topic."""

    summary: str
    # Falls back to the catch-all topic when the model omits the field.
    topic: TopicEnum = TopicEnum.Other_Rheumatic_Diseases
|
|
|
class PaperSchema(BaseModel):
    """Bibliographic record for a single paper.

    NOTE(review): not referenced elsewhere in this chunk — presumably used by
    callers outside this view; verify before removing.
    """

    title: str
    authors: str
    journal: str
    # PubMed identifier, kept as a string (used to build PubMed URLs).
    pmid: str
|
|
|
class TopicSummarySchema(BaseModel):
    """Per-topic newsletter section: scratch planning plus the final summary."""

    # Optional scratchpad the model may use to organize thoughts first.
    planning: str
    summary: str
|
|
|
def evaluate_relevance(title: str, abstract: str) -> EvaluationSchema:
    """Score how relevant a paper is for a rheumatologist audience.

    Args:
        title: Paper title.
        abstract: Paper abstract text.

    Returns:
        A validated ``EvaluationSchema`` with the model's reasoning and a
        0-10 relevance score.

    Raises:
        Propagates API, JSON-decoding, and validation errors after logging.
    """
    prompt = f"""
    Title: {title}
    Abstract: {abstract}
    Instructions: Evaluate the relevance of this medical abstract for an audience of rheumatologists on a scale of 0 to 10 with 10 being reserved only for large clinical trials in rheumatology.
    Be very discerning and only give a score above 8 for papers that are highly clinically relevant to rheumatologists.
    Respond in JSON format using the following schema:
    {json.dumps(EvaluationSchema.model_json_schema())}
    """

    try:
        # Delegate to the shared retrying helper instead of duplicating the
        # text_generation + json.loads logic inline (gains backoff/retry).
        result = _make_api_call(
            client1,
            prompt,
            max_tokens=512,
            temp=0.2,
            schema=EvaluationSchema.model_json_schema(),
        )
        # Validate so the declared return type is honored; the original
        # returned the raw dict despite the EvaluationSchema annotation.
        return EvaluationSchema(**result)
    except Exception as e:
        logger.error(f"Error in evaluate_relevance: {e}")
        raise
|
|
|
def summarize_abstract(abstract: str) -> SummarySchema:
    """Summarize one abstract in a sentence and assign it a topic.

    Args:
        abstract: Paper abstract text.

    Returns:
        A validated ``SummarySchema`` (summary sentence + ``TopicEnum`` topic).

    Raises:
        Propagates API, JSON-decoding, and validation errors after logging.
    """
    # BUG FIX: the original interpolated `TopicEnum.__doc__`, which is just
    # the inherited Enum boilerplate ("An enumeration."), not the topic list —
    # so the model was never shown the allowed topics. Send the real values.
    topic_options = json.dumps([topic.value for topic in TopicEnum])

    prompt = f"""
    Abstract: {abstract}
    Instructions: Summarize this medical abstract in 1 sentence and select the most relevant topic from the following enum:
    {topic_options}
    Respond in JSON format using the following schema:
    {json.dumps(SummarySchema.model_json_schema())}
    """

    try:
        # Shared retrying helper replaces the duplicated inline call logic.
        result = _make_api_call(
            client1,
            prompt,
            max_tokens=512,
            temp=0.2,
            schema=SummarySchema.model_json_schema(),
        )
        # Validate so the declared SummarySchema return type is honored.
        return SummarySchema(**result)
    except Exception as e:
        logger.error(f"Error in summarize_abstract: {e}")
        raise
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def _make_api_call(client, prompt, max_tokens=4096, temp=0.2, schema=None):
    """Send one text-generation request and return its parsed JSON reply.

    Retried up to three times with exponential backoff (tenacity). When
    ``schema`` is supplied it is forwarded as a JSON grammar constraint so
    the model's output is forced to match it.

    Raises:
        Propagates API and JSON-decoding errors after logging each attempt.
    """
    grammar = {"type": "json", "value": schema} if schema else None
    try:
        raw = client.text_generation(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temp,
            grammar=grammar,
        )
        return json.loads(raw)
    except Exception as exc:
        logger.error(f"API call failed: {exc}")
        raise
|
|
|
def compose_newsletter(papers: pd.DataFrame) -> str:
    """Build a markdown newsletter from scored papers, grouped by topic.

    Args:
        papers: DataFrame with at least the columns Title, Authors, Journal,
            PMID, Summary, and Topic.

    Returns:
        Markdown text with one section per topic (summary + reference list),
        or an empty string when ``papers`` is empty. Topics whose generation
        fails are logged and skipped rather than aborting the whole run.
    """
    if papers.empty:
        logger.info("No papers provided to compose the newsletter.")
        return ""

    content = ["# This Week in Rheumatology\n"]

    for topic in papers['Topic'].unique():
        # BUG FIX: keep `result` bound before the try block — the original
        # referenced it unconditionally in the except handler, raising
        # NameError whenever the failure happened before the API call
        # returned (e.g. inside rename/to_dict/json.dumps).
        result = None
        try:
            # Filter once and reuse for both the prompt and the references
            # (the original recomputed this mask a second time below).
            relevant_papers = papers[papers['Topic'] == topic]

            papers_dict = relevant_papers.rename(columns={
                'Title': 'title',
                'Authors': 'authors',
                'Journal': 'journal',
                'PMID': 'pmid',
                'Summary': 'summary'
            }).to_dict('records')

            prompt = f"""
            Instructions: Generate a brief summary of the latest research on {topic} using the following papers.
            Papers: {json.dumps(papers_dict)}
            Respond in JSON format using the following schema:
            {json.dumps(TopicSummarySchema.model_json_schema())}
            You have the option of using the planning field first to organize your thoughts before writing the summary.
            The summary should be concise, but because you are summarizing several papers, it should be detailed enough to give the reader a good idea of the latest research in the field.
            The papers may be somewhat disjointed, so you will need to think carefully about how you can transition between them with clever wording.
            You can use anywhere from 1 to 3 paragraphs for the summary.
            """

            result = _make_api_call(
                client2,
                prompt,
                max_tokens=4096,
                temp=0.2,
                schema=TopicSummarySchema.model_json_schema()
            )
            logger.debug(f"Raw response from Hugging Face: {result}")

            # Validate the response against the schema before rendering.
            summary = TopicSummarySchema(**result)

            topic_content = f"## {topic}\n\n"
            topic_content += f"{summary.summary}\n\n"
            topic_content += "### References\n\n"
            for _, paper in relevant_papers.iterrows():
                topic_content += (f"- {paper['Title']} by {paper['Authors']}. {paper['Journal']}. "
                                  f"[PMID: {paper['PMID']}](https://pubmed.ncbi.nlm.nih.gov/{paper['PMID']}/)\n")

            content.append(topic_content)

        except Exception as e:
            logger.error(f"Error processing topic {topic}: {e}")
            logger.error(f"Raw response: {result}")
            continue

    return "\n".join(content)
|
|