|
import os |
|
import requests |
|
from dotenv import load_dotenv |
|
|
|
load_dotenv() |
|
|
|
# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Model identifier interpolated into the inference endpoint URL below.
HF_BART_MODEL = "facebook/bart-large-mnli"
HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_BART_MODEL}"

# Default headers sent with every request to HF_API_URL. The Authorization
# entry is added only when a token is configured: formatting a missing token
# would otherwise send the literal (and invalid) credential "Bearer None".
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"
|
|
|
def run_zero_shot_classification(prompt: str, candidate_labels: list[str]) -> str:
    """Return the first label the zero-shot endpoint reports for ``prompt``.

    Posts ``prompt`` and ``candidate_labels`` to ``HF_API_URL`` and returns
    the first entry of the ``labels`` list in the JSON response (presumably
    the highest-scoring label; this function relies on the API's ordering and
    does not inspect scores itself). The call is best-effort: request
    failures are printed and reported as an empty string instead of raised.

    Args:
        prompt: Text to classify.
        candidate_labels: Labels the model should choose between.

    Returns:
        The first label in the response, or ``""`` when either argument is
        empty, the request fails, the body is not valid JSON, or the response
        carries no non-empty ``labels`` list.
    """
    # Nothing to classify, or nothing to choose from: skip the network call.
    if not prompt or not candidate_labels:
        return ""

    payload = {
        "inputs": prompt,
        "parameters": {"candidate_labels": candidate_labels},
    }

    # Only the request and JSON decoding are guarded, so programming errors
    # elsewhere are not silently swallowed.
    try:
        response = requests.post(
            HF_API_URL, headers=HEADERS, json=payload, timeout=20
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        print("[BART Zero-Shot Error]:", str(e))
        return ""

    # Validate the shape explicitly so an unexpected payload yields "".
    labels = data.get("labels") if isinstance(data, dict) else None
    if isinstance(labels, list) and labels:
        return labels[0]
    return ""
|
|
|
def run_entailment_check(premise: str, hypothesis: str) -> bool:
    """Report whether the model scores ``hypothesis`` as entailed by ``premise``.

    The request is sent to ``HF_API_URL`` in zero-shot form: ``premise`` is
    the input text and ``hypothesis`` is the only candidate label, with a
    ``"{}"`` hypothesis template so the hypothesis is used verbatim. The
    previous payload (a ``premise``/``hypothesis`` dict with no candidate
    labels) was parsed for an ``"entailment"`` label that a zero-shot
    response never contains, so the check could not succeed.

    The call is best-effort: request failures are printed and reported as
    ``False`` instead of raised.

    Args:
        premise: Text assumed to be true.
        hypothesis: Statement to test against the premise.

    Returns:
        ``True`` when the score returned for the hypothesis is greater than
        0.5; ``False`` otherwise, including when either argument is empty,
        the request fails, or the response has no usable ``scores`` list.
    """
    # Nothing to compare: skip the network call.
    if not premise or not hypothesis:
        return False

    # NOTE(review): relies on the zero-shot pipeline scoring a single label
    # as an entailment-vs-contradiction probability and on the endpoint
    # honouring `hypothesis_template` -- confirm against the Hugging Face
    # zero-shot classification documentation.
    payload = {
        "inputs": premise,
        "parameters": {
            "candidate_labels": [hypothesis],
            "hypothesis_template": "{}",
            "multi_label": True,
        },
    }

    # Only the request and JSON decoding are guarded, so programming errors
    # elsewhere are not silently swallowed.
    try:
        response = requests.post(
            HF_API_URL, headers=HEADERS, json=payload, timeout=20
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        print("[BART Entailment Error]:", str(e))
        return False

    # Validate the shape explicitly; a missing or malformed "scores" entry
    # means "not entailed" rather than an exception.
    scores = data.get("scores") if isinstance(data, dict) else None
    if not isinstance(scores, list) or not scores:
        return False

    # Exactly one candidate label was sent, so its score is the first entry.
    entailment_score = scores[0]
    if isinstance(entailment_score, bool) or not isinstance(
        entailment_score, (int, float)
    ):
        return False
    return entailment_score > 0.5
|
|