FastAPI / models_initialization /bart_large_registry.py
raghavNCI
too many changes, hope this works
f00f379
raw
history blame
1.66 kB
import os
import requests
from dotenv import load_dotenv
load_dotenv()

# Hugging Face access token, read from the environment (a .env file is loaded above).
HF_TOKEN = os.getenv("HF_TOKEN")
HF_BART_MODEL = "facebook/bart-large-mnli"
HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_BART_MODEL}"

# Request headers shared by every call to HF_API_URL in this module.
# Authorization is only added when a token is configured: formatting a missing
# token would otherwise send the literal header value "Bearer None".
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"
def run_zero_shot_classification(prompt: str, candidate_labels: list[str]) -> str:
    """Return the most likely label for *prompt* among *candidate_labels*.

    Sends a zero-shot classification request to the Hugging Face Inference
    API endpoint in ``HF_API_URL``.

    Args:
        prompt: Text to classify.
        candidate_labels: Labels to rank the prompt against.

    Returns:
        The highest-scoring label, or ``""`` when no labels are given, the
        request fails, or the response has an unrecognized shape. Failures are
        printed rather than raised (best-effort behavior).
    """
    labels = list(candidate_labels)
    # With no labels there is nothing to rank, so skip the network call.
    if not labels:
        return ""

    payload = {
        "inputs": prompt,
        "parameters": {"candidate_labels": labels},
    }
    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=20)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError also covers a non-JSON response body on older `requests`.
        print("[BART Zero-Shot Error]:", str(e))
        return ""

    # Pipeline-style shape: {"sequence": ..., "labels": [...], "scores": [...]},
    # where labels are ordered most confident first.
    if isinstance(data, dict):
        ranked = data.get("labels")
        if isinstance(ranked, list) and ranked:
            return ranked[0]
        return ""

    # NOTE(review): some inference endpoints return a list of
    # {"label": ..., "score": ...} entries instead; handled defensively, confirm.
    if isinstance(data, list):
        entries = [
            item
            for item in data
            if isinstance(item, dict)
            and "label" in item
            and isinstance(item.get("score"), (int, float))
        ]
        if entries:
            return max(entries, key=lambda item: item["score"])["label"]
    return ""
def run_entailment_check(premise: str, hypothesis: str, threshold: float = 0.5) -> bool:
    """Return True if *premise* is judged to entail *hypothesis*.

    ``facebook/bart-large-mnli`` is served through a zero-shot classification
    pipeline. That pipeline takes a text plus ``candidate_labels`` and has no
    premise/hypothesis input. So the premise is sent as the text and the
    hypothesis as the only candidate label, and ``hypothesis_template="{}"``
    uses the hypothesis verbatim. ``multi_label=True`` asks the pipeline to
    score the label on its own (entailment vs. contradiction) instead of
    normalizing across labels.

    Args:
        premise: Text taken as given.
        hypothesis: Statement to test against the premise.
        threshold: Score the entailment probability must exceed.

    Returns:
        True when the returned score exceeds *threshold*. Returns False
        otherwise, including for empty inputs, request failures, and
        unrecognized responses. Errors are printed, not raised.
    """
    # An empty premise or hypothesis cannot express a meaningful pair.
    if not premise or not hypothesis:
        return False

    payload = {
        "inputs": premise,
        "parameters": {
            "candidate_labels": [hypothesis],
            "hypothesis_template": "{}",
            "multi_label": True,
        },
    }
    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=20)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError also covers a non-JSON response body on older `requests`.
        print("[BART Entailment Error]:", str(e))
        return False

    # Exactly one candidate label was sent, so the first score belongs to it.
    score = None
    if isinstance(data, dict):
        scores = data.get("scores")
        if isinstance(scores, list) and scores:
            score = scores[0]
    elif isinstance(data, list) and data and isinstance(data[0], dict):
        # NOTE(review): alternate [{"label": ..., "score": ...}] shape; confirm.
        score = data[0].get("score")

    if not isinstance(score, (int, float)):
        return False
    return score > threshold