pavankumarvk committed on
Commit
28c4d49
·
verified ·
1 Parent(s): 90121fd

Update text_detector_inference.py

Browse files
Files changed (1) hide show
  1. text_detector_inference.py +117 -37
text_detector_inference.py CHANGED
@@ -2,68 +2,96 @@
2
  text_detector_inference.py
3
  ==========================
4
  Inference wrapper for HybridAITextDetector.
5
- Designed to be imported by app.py (Gradio) in the Hugging Face Space.
 
 
 
 
 
6
 
7
  Usage
8
  -----
9
  from text_detector_inference import TextDetectorInference
10
 
11
- detector = TextDetectorInference(
12
- checkpoint="best_text_detector.pt",
13
- threshold=0.5
14
- )
15
- result = detector.predict("Some text here...")
16
  """
17
 
18
  import os
19
  import torch
20
- from transformers import AutoTokenizer
21
  from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
22
 
 
 
 
 
 
 
23
 
24
  class TextDetectorInference:
25
  """
26
- Thin wrapper around HybridAITextDetector for single-text prediction.
 
27
 
28
  Parameters
29
  ----------
30
  checkpoint : str
31
- Path to the .pt state-dict file.
32
  threshold : float
33
  Decision boundary for the sigmoid probability (default 0.5).
34
- Set to the optimal F1 threshold found during evaluation.
35
- device : torch.device or None
36
  Auto-detects CUDA if None.
37
  """
38
 
39
  def __init__(
40
  self,
41
- checkpoint: str = "best_text_detector.pt",
42
  threshold: float = 0.5,
43
  device: torch.device = None,
44
  ):
45
- self.threshold = threshold
46
- self.device = device or torch.device(
47
  "cuda" if torch.cuda.is_available() else "cpu"
48
  )
49
-
50
- print(f"[TextDetector] Loading tokenizer from {MODEL_NAME}...")
51
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
52
 
53
  if os.path.exists(checkpoint):
54
- print(f"[TextDetector] Loading checkpoint: {checkpoint}")
 
 
 
55
  self.model = HybridAITextDetector()
56
  self.model.load_state_dict(
57
  torch.load(checkpoint, map_location=self.device)
58
  )
59
  self.model.eval().to(self.device)
60
- print("[TextDetector] ✅ Model ready")
 
61
  else:
62
- print(f"[TextDetector] ⚠️ Checkpoint '{checkpoint}' not found. "
63
- "Model NOT loaded — predictions will fail.")
64
- self.model = None
65
-
66
- # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def predict(self, text: str) -> dict:
68
  """
69
  Classify a single text string.
@@ -72,17 +100,31 @@ class TextDetectorInference:
72
  -------
73
  dict with keys:
74
  label : "AI-Generated" or "Human-Written"
75
- confidence : probability of the predicted class (0-1)
76
  ai_prob : raw P(AI-generated)
77
  human_prob : 1 - ai_prob
 
78
  """
79
- if self.model is None:
80
- return {"error": "Model not loaded — missing checkpoint file."}
81
-
82
  text = text.strip()
83
  if not text:
84
  return {"error": "Input text is empty."}
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  enc = self.tokenizer(
87
  text,
88
  truncation=True,
@@ -90,7 +132,6 @@ class TextDetectorInference:
90
  max_length=MAX_LENGTH,
91
  return_tensors="pt",
92
  )
93
-
94
  input_ids = enc["input_ids"].to(self.device)
95
  attention_mask = enc["attention_mask"].to(self.device)
96
  token_type_ids = enc.get(
@@ -99,8 +140,8 @@ class TextDetectorInference:
99
  ).to(self.device)
100
 
101
  with torch.no_grad():
102
- logit = self.model(input_ids, attention_mask, token_type_ids)
103
- ai_prob = torch.sigmoid(logit).item()
104
 
105
  human_prob = 1.0 - ai_prob
106
  is_ai = ai_prob >= self.threshold
@@ -112,20 +153,59 @@ class TextDetectorInference:
112
  "confidence": round(confidence, 4),
113
  "ai_prob": round(ai_prob, 4),
114
  "human_prob": round(human_prob, 4),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  }
116
 
117
- # ------------------------------------------------------------------
118
- def predict_batch(self, texts: list[str]) -> list[dict]:
119
  """Run predict() on a list of texts. Returns list of result dicts."""
120
  return [self.predict(t) for t in texts]
121
 
122
- # ------------------------------------------------------------------
123
- def format_for_gradio(self, text: str) -> tuple[str, float, dict]:
124
  """
125
- Convenience wrapper that returns values in a Gradio-friendly format:
126
  (label_string, confidence_float, full_result_dict)
127
  """
128
  result = self.predict(text)
129
  if "error" in result:
130
  return result["error"], 0.0, result
131
- return result["label"], result["confidence"], result
 
2
  text_detector_inference.py
3
  ==========================
4
  Inference wrapper for HybridAITextDetector.
5
+
6
+ Strategy
7
+ --------
8
+ 1. If ``best_text_detector.pt`` exists → load the custom trained model.
9
+ 2. Otherwise → fall back to a pretrained
10
+ HuggingFace AI-text detector so the Space keeps working immediately.
11
 
12
  Usage
13
  -----
14
  from text_detector_inference import TextDetectorInference
15
 
16
+ detector = TextDetectorInference() # auto-detects checkpoint
17
+ result = detector.predict("Some text…")
 
 
 
18
  """
19
 
20
  import os
21
  import torch
22
+ from transformers import AutoTokenizer, pipeline as hf_pipeline
23
  from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
24
 
25
+ # ─── Fallback model ───────────────────────────────────────────────────────────
26
+ # Used when best_text_detector.pt is not present in the Space.
27
+ # "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available,
28
+ # well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs).
29
+ FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta"
30
+
31
 
32
  class TextDetectorInference:
33
  """
34
+ Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
35
+ for single-text prediction.
36
 
37
  Parameters
38
  ----------
39
  checkpoint : str
40
+ Path to the .pt state-dict file for the custom model.
41
  threshold : float
42
  Decision boundary for the sigmoid probability (default 0.5).
43
+ Set to the optimal F1 threshold found during your training run.
44
+ device : torch.device | None
45
  Auto-detects CUDA if None.
46
  """
47
 
48
  def __init__(
49
  self,
50
+ checkpoint: str = "best_text_detector.pt",
51
  threshold: float = 0.5,
52
  device: torch.device = None,
53
  ):
54
+ self.threshold = threshold
55
+ self.device = device or torch.device(
56
  "cuda" if torch.cuda.is_available() else "cpu"
57
  )
58
+ self._use_custom = False
59
+ self._fallback = None
60
+ self.model = None
61
+ self.tokenizer = None
62
 
63
  if os.path.exists(checkpoint):
64
+ # ── Load custom trained HybridAITextDetector ──────────────────────
65
+ print(f"[TextDetector] ✅ Checkpoint found: {checkpoint}")
66
+ print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …")
67
+ self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
68
  self.model = HybridAITextDetector()
69
  self.model.load_state_dict(
70
  torch.load(checkpoint, map_location=self.device)
71
  )
72
  self.model.eval().to(self.device)
73
+ self._use_custom = True
74
+ print("[TextDetector] ✅ Custom model ready")
75
  else:
76
+ # ── Fall back to pretrained HuggingFace model ─────────────────────
77
+ print(
78
+ f"[TextDetector] ⚠️ '{checkpoint}' not found.\n"
79
+ f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}"
80
+ )
81
+ try:
82
+ self._fallback = hf_pipeline(
83
+ "text-classification",
84
+ model=FALLBACK_MODEL_ID,
85
+ device=0 if torch.cuda.is_available() else -1,
86
+ truncation=True,
87
+ max_length=512,
88
+ )
89
+ print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})")
90
+ except Exception as e:
91
+ print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
92
+ self._fallback = None
93
+
94
+ # ──────────────────────────────────────────────────────────────────────────
95
  def predict(self, text: str) -> dict:
96
  """
97
  Classify a single text string.
 
100
  -------
101
  dict with keys:
102
  label : "AI-Generated" or "Human-Written"
103
+ confidence : probability of the predicted class (0–1)
104
  ai_prob : raw P(AI-generated)
105
  human_prob : 1 - ai_prob
106
+ source : "custom_model" | "pretrained_fallback"
107
  """
 
 
 
108
  text = text.strip()
109
  if not text:
110
  return {"error": "Input text is empty."}
111
 
112
+ if self._use_custom:
113
+ return self._predict_custom(text)
114
+ elif self._fallback is not None:
115
+ return self._predict_fallback(text)
116
+ else:
117
+ return {
118
+ "error": (
119
+ "No model available. Upload 'best_text_detector.pt' to the "
120
+ "Space, or check your internet connection so the fallback "
121
+ "model can be downloaded."
122
+ )
123
+ }
124
+
125
+ # ──────────────────────────────────────────────────────────────────────────
126
+ def _predict_custom(self, text: str) -> dict:
127
+ """Run inference with the custom HybridAITextDetector checkpoint."""
128
  enc = self.tokenizer(
129
  text,
130
  truncation=True,
 
132
  max_length=MAX_LENGTH,
133
  return_tensors="pt",
134
  )
 
135
  input_ids = enc["input_ids"].to(self.device)
136
  attention_mask = enc["attention_mask"].to(self.device)
137
  token_type_ids = enc.get(
 
140
  ).to(self.device)
141
 
142
  with torch.no_grad():
143
+ logit = self.model(input_ids, attention_mask, token_type_ids)
144
+ ai_prob = torch.sigmoid(logit).item()
145
 
146
  human_prob = 1.0 - ai_prob
147
  is_ai = ai_prob >= self.threshold
 
153
  "confidence": round(confidence, 4),
154
  "ai_prob": round(ai_prob, 4),
155
  "human_prob": round(human_prob, 4),
156
+ "source": "custom_model",
157
+ }
158
+
159
+ # ──────────────────────────────────────────────────────────────────────────
160
+ def _predict_fallback(self, text: str) -> dict:
161
+ """
162
+ Run inference with the pretrained HuggingFace fallback model.
163
+
164
+ Hello-SimpleAI/chatgpt-detector-roberta returns:
165
+ {"label": "ChatGPT" | "Human", "score": float}
166
+ We normalise this to the same dict shape as _predict_custom.
167
+ """
168
+ try:
169
+ raw = self._fallback(text)[0] # {"label": ..., "score": ...}
170
+ except Exception as e:
171
+ return {"error": f"Fallback inference failed: {e}"}
172
+
173
+ hf_label = raw["label"].strip().lower() # "chatgpt" or "human"
174
+ score = float(raw["score"]) # confidence of the returned label
175
+
176
+ if hf_label in ("chatgpt", "ai", "fake", "generated"):
177
+ ai_prob = score
178
+ human_prob = 1.0 - score
179
+ label = "AI-Generated"
180
+ else:
181
+ human_prob = score
182
+ ai_prob = 1.0 - score
183
+ label = "Human-Written"
184
+
185
+ is_ai = ai_prob >= self.threshold
186
+ label = "AI-Generated" if is_ai else "Human-Written"
187
+ confidence = ai_prob if is_ai else human_prob
188
+
189
+ return {
190
+ "label": label,
191
+ "confidence": round(confidence, 4),
192
+ "ai_prob": round(ai_prob, 4),
193
+ "human_prob": round(human_prob, 4),
194
+ "source": "pretrained_fallback",
195
  }
196
 
197
+ # ──────────────────────────────────────────────────────────────────────────
198
+ def predict_batch(self, texts: list) -> list:
199
  """Run predict() on a list of texts. Returns list of result dicts."""
200
  return [self.predict(t) for t in texts]
201
 
202
+ # ──────────────────────────────────────────────────────────────────────────
203
+ def format_for_gradio(self, text: str) -> tuple:
204
  """
205
+ Convenience wrapper returning Gradio-friendly values:
206
  (label_string, confidence_float, full_result_dict)
207
  """
208
  result = self.predict(text)
209
  if "error" in result:
210
  return result["error"], 0.0, result
211
+ return result["label"], result["confidence"], result