petter2025 committed on
Commit
c2c492b
·
verified ·
1 Parent(s): b00631a

Delete nli_detector.py

Browse files
Files changed (1) hide show
  1. nli_detector.py +0 -63
nli_detector.py DELETED
@@ -1,63 +0,0 @@
1
- """
2
- Natural Language Inference detector – checks if generated response is consistent with input.
3
- """
4
- import logging
5
- from typing import Optional
6
- import torch
7
- from transformers import pipeline
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
class NLIDetector:
    """
    Uses an NLI model to detect contradictions/hallucinations.

    Returns entailment probability (0 to 1) for a given premise-hypothesis pair.
    If the model cannot be loaded, the instance is still created but
    ``check`` returns ``None`` for every call.
    """

    def __init__(self, model_name: str = "microsoft/deberta-base-mnli"):
        """
        Load the text-classification pipeline used for NLI scoring.

        Args:
            model_name: Hugging Face model identifier for NLI.
                Default is a public model that does not require authentication.

        Any exception raised while loading is logged and leaves
        ``self.pipeline`` set to ``None`` instead of propagating.
        """
        try:
            # Request all scores to obtain probabilities for each class,
            # not just the top label.
            self.pipeline = pipeline(
                "text-classification",
                model=model_name,
                device=0 if torch.cuda.is_available() else -1,
                return_all_scores=True,
            )
            logger.info(f"NLI model {model_name} loaded with return_all_scores=True.")
        except Exception as e:
            logger.error(f"Failed to load NLI model: {e}")
            self.pipeline = None

    def check(self, premise: str, hypothesis: str) -> Optional[float]:
        """
        Returns probability of entailment (higher means more consistent).

        Args:
            premise: The original input/context.
            hypothesis: The generated response.

        Returns:
            Float between 0 and 1, 0.0 if the model exposes no entailment
            label, or None if the model is unavailable or inference fails.
        """
        if self.pipeline is None:
            return None
        try:
            # Pass the pair as text/text_pair so the model's own tokenizer
            # inserts the correct separator tokens. A hand-written
            # "</s></s>" only matches RoBERTa-style vocabularies and would be
            # tokenized as literal text by other tokenizers.
            result = self.pipeline({"text": premise, "text_pair": hypothesis})
            # Depending on input type the pipeline may return either a flat
            # list of class-score dicts or that list wrapped in an outer
            # list; accept both shapes.
            if result and isinstance(result[0], list):
                scores = result[0]
            else:
                scores = result
            for item in scores:
                # Label casing differs between checkpoints
                # ("ENTAILMENT" vs "entailment"), so compare case-insensitively.
                if str(item["label"]).upper() == "ENTAILMENT":
                    return float(item["score"])
            # If the label is not found, fall back to 0.0.
            logger.warning("ENTAILMENT label not found in NLI output; returning 0.0.")
            return 0.0
        except Exception as e:
            logger.error(f"NLI error: {e}")
            return None