Update text_detector_inference.py
Browse files- text_detector_inference.py +117 -37
text_detector_inference.py
CHANGED
|
@@ -2,68 +2,96 @@
|
|
| 2 |
text_detector_inference.py
|
| 3 |
==========================
|
| 4 |
Inference wrapper for HybridAITextDetector.
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
Usage
|
| 8 |
-----
|
| 9 |
from text_detector_inference import TextDetectorInference
|
| 10 |
|
| 11 |
-
detector = TextDetectorInference(
|
| 12 |
-
|
| 13 |
-
threshold=0.5
|
| 14 |
-
)
|
| 15 |
-
result = detector.predict("Some text here...")
|
| 16 |
"""
|
| 17 |
|
| 18 |
import os
|
| 19 |
import torch
|
| 20 |
-
from transformers import AutoTokenizer
|
| 21 |
from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class TextDetectorInference:
|
| 25 |
"""
|
| 26 |
-
Thin wrapper around HybridAITextDetector
|
|
|
|
| 27 |
|
| 28 |
Parameters
|
| 29 |
----------
|
| 30 |
checkpoint : str
|
| 31 |
-
Path to the .pt state-dict file.
|
| 32 |
threshold : float
|
| 33 |
Decision boundary for the sigmoid probability (default 0.5).
|
| 34 |
-
Set to the optimal F1 threshold found during
|
| 35 |
-
device : torch.device
|
| 36 |
Auto-detects CUDA if None.
|
| 37 |
"""
|
| 38 |
|
| 39 |
def __init__(
|
| 40 |
self,
|
| 41 |
-
checkpoint: str
|
| 42 |
threshold: float = 0.5,
|
| 43 |
device: torch.device = None,
|
| 44 |
):
|
| 45 |
-
self.threshold
|
| 46 |
-
self.device
|
| 47 |
"cuda" if torch.cuda.is_available() else "cpu"
|
| 48 |
)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
self.
|
|
|
|
| 52 |
|
| 53 |
if os.path.exists(checkpoint):
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
self.model = HybridAITextDetector()
|
| 56 |
self.model.load_state_dict(
|
| 57 |
torch.load(checkpoint, map_location=self.device)
|
| 58 |
)
|
| 59 |
self.model.eval().to(self.device)
|
| 60 |
-
|
|
|
|
| 61 |
else:
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def predict(self, text: str) -> dict:
|
| 68 |
"""
|
| 69 |
Classify a single text string.
|
|
@@ -72,17 +100,31 @@ class TextDetectorInference:
|
|
| 72 |
-------
|
| 73 |
dict with keys:
|
| 74 |
label : "AI-Generated" or "Human-Written"
|
| 75 |
-
confidence : probability of the predicted class (0
|
| 76 |
ai_prob : raw P(AI-generated)
|
| 77 |
human_prob : 1 - ai_prob
|
|
|
|
| 78 |
"""
|
| 79 |
-
if self.model is None:
|
| 80 |
-
return {"error": "Model not loaded — missing checkpoint file."}
|
| 81 |
-
|
| 82 |
text = text.strip()
|
| 83 |
if not text:
|
| 84 |
return {"error": "Input text is empty."}
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
enc = self.tokenizer(
|
| 87 |
text,
|
| 88 |
truncation=True,
|
|
@@ -90,7 +132,6 @@ class TextDetectorInference:
|
|
| 90 |
max_length=MAX_LENGTH,
|
| 91 |
return_tensors="pt",
|
| 92 |
)
|
| 93 |
-
|
| 94 |
input_ids = enc["input_ids"].to(self.device)
|
| 95 |
attention_mask = enc["attention_mask"].to(self.device)
|
| 96 |
token_type_ids = enc.get(
|
|
@@ -99,8 +140,8 @@ class TextDetectorInference:
|
|
| 99 |
).to(self.device)
|
| 100 |
|
| 101 |
with torch.no_grad():
|
| 102 |
-
logit
|
| 103 |
-
ai_prob
|
| 104 |
|
| 105 |
human_prob = 1.0 - ai_prob
|
| 106 |
is_ai = ai_prob >= self.threshold
|
|
@@ -112,20 +153,59 @@ class TextDetectorInference:
|
|
| 112 |
"confidence": round(confidence, 4),
|
| 113 |
"ai_prob": round(ai_prob, 4),
|
| 114 |
"human_prob": round(human_prob, 4),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
}
|
| 116 |
|
| 117 |
-
#
|
| 118 |
-
def predict_batch(self, texts: list
|
| 119 |
"""Run predict() on a list of texts. Returns list of result dicts."""
|
| 120 |
return [self.predict(t) for t in texts]
|
| 121 |
|
| 122 |
-
#
|
| 123 |
-
def format_for_gradio(self, text: str) -> tuple
|
| 124 |
"""
|
| 125 |
-
Convenience wrapper
|
| 126 |
(label_string, confidence_float, full_result_dict)
|
| 127 |
"""
|
| 128 |
result = self.predict(text)
|
| 129 |
if "error" in result:
|
| 130 |
return result["error"], 0.0, result
|
| 131 |
-
return result["label"], result["confidence"], result
|
|
|
|
| 2 |
text_detector_inference.py
|
| 3 |
==========================
|
| 4 |
Inference wrapper for HybridAITextDetector.
|
| 5 |
+
|
| 6 |
+
Strategy
|
| 7 |
+
--------
|
| 8 |
+
1. If ``best_text_detector.pt`` exists → load the custom trained model.
|
| 9 |
+
2. Otherwise → fall back to a pretrained
|
| 10 |
+
HuggingFace AI-text detector so the Space keeps working immediately.
|
| 11 |
|
| 12 |
Usage
|
| 13 |
-----
|
| 14 |
from text_detector_inference import TextDetectorInference
|
| 15 |
|
| 16 |
+
detector = TextDetectorInference() # auto-detects checkpoint
|
| 17 |
+
result = detector.predict("Some text…")
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import os
|
| 21 |
import torch
|
| 22 |
+
from transformers import AutoTokenizer, pipeline as hf_pipeline
|
| 23 |
from text_detector_model import HybridAITextDetector, MODEL_NAME, MAX_LENGTH
|
| 24 |
|
| 25 |
+
# ─── Fallback model ───────────────────────────────────────────────────────────
|
| 26 |
+
# Used when best_text_detector.pt is not present in the Space.
|
| 27 |
+
# "Hello-SimpleAI/chatgpt-detector-roberta" is a publicly available,
|
| 28 |
+
# well-validated AI-text detector (RoBERTa fine-tuned on ChatGPT outputs).
|
| 29 |
+
FALLBACK_MODEL_ID = "Hello-SimpleAI/chatgpt-detector-roberta"
|
| 30 |
+
|
| 31 |
|
| 32 |
class TextDetectorInference:
|
| 33 |
"""
|
| 34 |
+
Thin wrapper around HybridAITextDetector (or a fallback pretrained model)
|
| 35 |
+
for single-text prediction.
|
| 36 |
|
| 37 |
Parameters
|
| 38 |
----------
|
| 39 |
checkpoint : str
|
| 40 |
+
Path to the .pt state-dict file for the custom model.
|
| 41 |
threshold : float
|
| 42 |
Decision boundary for the sigmoid probability (default 0.5).
|
| 43 |
+
Set to the optimal F1 threshold found during your training run.
|
| 44 |
+
device : torch.device | None
|
| 45 |
Auto-detects CUDA if None.
|
| 46 |
"""
|
| 47 |
|
| 48 |
def __init__(
|
| 49 |
self,
|
| 50 |
+
checkpoint: str = "best_text_detector.pt",
|
| 51 |
threshold: float = 0.5,
|
| 52 |
device: torch.device = None,
|
| 53 |
):
|
| 54 |
+
self.threshold = threshold
|
| 55 |
+
self.device = device or torch.device(
|
| 56 |
"cuda" if torch.cuda.is_available() else "cpu"
|
| 57 |
)
|
| 58 |
+
self._use_custom = False
|
| 59 |
+
self._fallback = None
|
| 60 |
+
self.model = None
|
| 61 |
+
self.tokenizer = None
|
| 62 |
|
| 63 |
if os.path.exists(checkpoint):
|
| 64 |
+
# ── Load custom trained HybridAITextDetector ──────────────────────
|
| 65 |
+
print(f"[TextDetector] ✅ Checkpoint found: {checkpoint}")
|
| 66 |
+
print(f"[TextDetector] Loading tokenizer from {MODEL_NAME} …")
|
| 67 |
+
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 68 |
self.model = HybridAITextDetector()
|
| 69 |
self.model.load_state_dict(
|
| 70 |
torch.load(checkpoint, map_location=self.device)
|
| 71 |
)
|
| 72 |
self.model.eval().to(self.device)
|
| 73 |
+
self._use_custom = True
|
| 74 |
+
print("[TextDetector] ✅ Custom model ready")
|
| 75 |
else:
|
| 76 |
+
# ── Fall back to pretrained HuggingFace model ─────────────────────
|
| 77 |
+
print(
|
| 78 |
+
f"[TextDetector] ⚠️ '{checkpoint}' not found.\n"
|
| 79 |
+
f"[TextDetector] Loading pretrained fallback: {FALLBACK_MODEL_ID}"
|
| 80 |
+
)
|
| 81 |
+
try:
|
| 82 |
+
self._fallback = hf_pipeline(
|
| 83 |
+
"text-classification",
|
| 84 |
+
model=FALLBACK_MODEL_ID,
|
| 85 |
+
device=0 if torch.cuda.is_available() else -1,
|
| 86 |
+
truncation=True,
|
| 87 |
+
max_length=512,
|
| 88 |
+
)
|
| 89 |
+
print(f"[TextDetector] ✅ Fallback model ready ({FALLBACK_MODEL_ID})")
|
| 90 |
+
except Exception as e:
|
| 91 |
+
print(f"[TextDetector] ❌ Fallback model failed to load: {e}")
|
| 92 |
+
self._fallback = None
|
| 93 |
+
|
| 94 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 95 |
def predict(self, text: str) -> dict:
|
| 96 |
"""
|
| 97 |
Classify a single text string.
|
|
|
|
| 100 |
-------
|
| 101 |
dict with keys:
|
| 102 |
label : "AI-Generated" or "Human-Written"
|
| 103 |
+
confidence : probability of the predicted class (0–1)
|
| 104 |
ai_prob : raw P(AI-generated)
|
| 105 |
human_prob : 1 - ai_prob
|
| 106 |
+
source : "custom_model" | "pretrained_fallback"
|
| 107 |
"""
|
|
|
|
|
|
|
|
|
|
| 108 |
text = text.strip()
|
| 109 |
if not text:
|
| 110 |
return {"error": "Input text is empty."}
|
| 111 |
|
| 112 |
+
if self._use_custom:
|
| 113 |
+
return self._predict_custom(text)
|
| 114 |
+
elif self._fallback is not None:
|
| 115 |
+
return self._predict_fallback(text)
|
| 116 |
+
else:
|
| 117 |
+
return {
|
| 118 |
+
"error": (
|
| 119 |
+
"No model available. Upload 'best_text_detector.pt' to the "
|
| 120 |
+
"Space, or check your internet connection so the fallback "
|
| 121 |
+
"model can be downloaded."
|
| 122 |
+
)
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 126 |
+
def _predict_custom(self, text: str) -> dict:
|
| 127 |
+
"""Run inference with the custom HybridAITextDetector checkpoint."""
|
| 128 |
enc = self.tokenizer(
|
| 129 |
text,
|
| 130 |
truncation=True,
|
|
|
|
| 132 |
max_length=MAX_LENGTH,
|
| 133 |
return_tensors="pt",
|
| 134 |
)
|
|
|
|
| 135 |
input_ids = enc["input_ids"].to(self.device)
|
| 136 |
attention_mask = enc["attention_mask"].to(self.device)
|
| 137 |
token_type_ids = enc.get(
|
|
|
|
| 140 |
).to(self.device)
|
| 141 |
|
| 142 |
with torch.no_grad():
|
| 143 |
+
logit = self.model(input_ids, attention_mask, token_type_ids)
|
| 144 |
+
ai_prob = torch.sigmoid(logit).item()
|
| 145 |
|
| 146 |
human_prob = 1.0 - ai_prob
|
| 147 |
is_ai = ai_prob >= self.threshold
|
|
|
|
| 153 |
"confidence": round(confidence, 4),
|
| 154 |
"ai_prob": round(ai_prob, 4),
|
| 155 |
"human_prob": round(human_prob, 4),
|
| 156 |
+
"source": "custom_model",
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 160 |
+
def _predict_fallback(self, text: str) -> dict:
|
| 161 |
+
"""
|
| 162 |
+
Run inference with the pretrained HuggingFace fallback model.
|
| 163 |
+
|
| 164 |
+
Hello-SimpleAI/chatgpt-detector-roberta returns:
|
| 165 |
+
{"label": "ChatGPT" | "Human", "score": float}
|
| 166 |
+
We normalise this to the same dict shape as _predict_custom.
|
| 167 |
+
"""
|
| 168 |
+
try:
|
| 169 |
+
raw = self._fallback(text)[0] # {"label": ..., "score": ...}
|
| 170 |
+
except Exception as e:
|
| 171 |
+
return {"error": f"Fallback inference failed: {e}"}
|
| 172 |
+
|
| 173 |
+
hf_label = raw["label"].strip().lower() # "chatgpt" or "human"
|
| 174 |
+
score = float(raw["score"]) # confidence of the returned label
|
| 175 |
+
|
| 176 |
+
if hf_label in ("chatgpt", "ai", "fake", "generated"):
|
| 177 |
+
ai_prob = score
|
| 178 |
+
human_prob = 1.0 - score
|
| 179 |
+
label = "AI-Generated"
|
| 180 |
+
else:
|
| 181 |
+
human_prob = score
|
| 182 |
+
ai_prob = 1.0 - score
|
| 183 |
+
label = "Human-Written"
|
| 184 |
+
|
| 185 |
+
is_ai = ai_prob >= self.threshold
|
| 186 |
+
label = "AI-Generated" if is_ai else "Human-Written"
|
| 187 |
+
confidence = ai_prob if is_ai else human_prob
|
| 188 |
+
|
| 189 |
+
return {
|
| 190 |
+
"label": label,
|
| 191 |
+
"confidence": round(confidence, 4),
|
| 192 |
+
"ai_prob": round(ai_prob, 4),
|
| 193 |
+
"human_prob": round(human_prob, 4),
|
| 194 |
+
"source": "pretrained_fallback",
|
| 195 |
}
|
| 196 |
|
| 197 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 198 |
+
def predict_batch(self, texts: list) -> list:
|
| 199 |
"""Run predict() on a list of texts. Returns list of result dicts."""
|
| 200 |
return [self.predict(t) for t in texts]
|
| 201 |
|
| 202 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 203 |
+
def format_for_gradio(self, text: str) -> tuple:
|
| 204 |
"""
|
| 205 |
+
Convenience wrapper returning Gradio-friendly values:
|
| 206 |
(label_string, confidence_float, full_result_dict)
|
| 207 |
"""
|
| 208 |
result = self.predict(text)
|
| 209 |
if "error" in result:
|
| 210 |
return result["error"], 0.0, result
|
| 211 |
+
return result["label"], result["confidence"], result
|