File size: 1,659 Bytes
f00f379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import requests
from dotenv import load_dotenv

# Load environment variables via python-dotenv before the token is read below.
load_dotenv()

# Access token for the Hugging Face Inference API; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_BART_MODEL = "facebook/bart-large-mnli"
HF_API_URL = f"https://api-inference.huggingface.co/models/{HF_BART_MODEL}"

# Default headers sent with every request made from this module.
HEADERS = {"Content-Type": "application/json"}
if HF_TOKEN:
    # Only attach credentials that actually exist: formatting a missing token
    # would otherwise send the literal header value "Bearer None".
    HEADERS["Authorization"] = f"Bearer {HF_TOKEN}"

def run_zero_shot_classification(prompt: str, candidate_labels: list[str]) -> str:
    """Return the candidate label the hosted model scores highest for *prompt*.

    Posts ``prompt`` and ``candidate_labels`` to ``HF_API_URL`` and reads the
    ``labels`` / ``scores`` lists from the JSON reply. The call is
    best-effort: any failure (network error, HTTP error status, unexpected
    reply) is printed and reported as an empty string instead of raised.

    Args:
        prompt: Text to classify.
        candidate_labels: Labels the model should choose between.

    Returns:
        The winning label, or ``""`` when an input is empty, the request
        fails, or the reply carries no labels.
    """
    # Nothing to classify, or nothing to choose from: skip the round trip.
    if not prompt or not candidate_labels:
        return ""

    payload = {
        "inputs": prompt,
        "parameters": {"candidate_labels": candidate_labels},
    }

    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=20)
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict):
            return ""

        labels = data.get("labels")
        if not labels:
            return ""

        scores = data.get("scores")
        if isinstance(scores, list) and len(scores) == len(labels):
            # Choose by score rather than trusting the reply to be pre-sorted.
            # max() keeps the first maximum, so an already-sorted reply still
            # resolves to labels[0].
            best = max(range(len(labels)), key=scores.__getitem__)
            return labels[best]

        # No usable scores: fall back to the first label in the reply.
        return labels[0]
    except Exception as e:  # deliberate catch-all: callers rely on "" on failure
        print("[BART Zero-Shot Error]:", str(e))

    return ""

def run_entailment_check(premise: str, hypothesis: str, *, threshold: float = 0.5) -> bool:
    """Report whether the hosted model scores "entailment" above *threshold*.

    Posts the premise/hypothesis pair to ``HF_API_URL`` and looks up the
    score paired with the ``"entailment"`` label in the JSON reply. The call
    is best-effort: any failure (network error, HTTP error status, unexpected
    reply) is printed and reported as ``False`` instead of raised.

    Args:
        premise: Statement taken as given.
        hypothesis: Statement to test against the premise.
        threshold: Score the "entailment" label must strictly exceed.

    Returns:
        ``True`` only when the reply contains an "entailment" label whose
        score is greater than ``threshold``; ``False`` otherwise.
    """
    # An empty side gives the model nothing to compare; skip the round trip.
    if not premise or not hypothesis:
        return False

    # NOTE(review): this posts a premise/hypothesis dict to the same URL that
    # run_zero_shot_classification calls with a string input plus
    # candidate_labels -- confirm the endpoint accepts this payload shape.
    payload = {
        "inputs": {
            "premise": premise,
            "hypothesis": hypothesis,
        }
    }

    try:
        response = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=20)
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict):
            return False

        labels = data.get("labels")
        scores = data.get("scores")
        if not labels or not scores or "entailment" not in labels:
            return False

        # labels and scores are read as parallel lists: same index, same class.
        return scores[labels.index("entailment")] > threshold
    except Exception as e:  # deliberate catch-all: callers rely on False on failure
        print("[BART Entailment Error]:", str(e))

    return False