Update app.py
app.py (CHANGED)
@@ -1,48 +1,44 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
-Gradio App — …
+Gradio App — AI vs Human Document Classifier (Chunked Inference)
+----------------------------------------------------------------
+Features:
+- Upload a document (TXT/MD/HTML/PDF), chunk if needed, classify each chunk, aggregate to document.
+- Shows:
+  1) Probability bars with raw numbers (AI generated / Human written)
+  2) Confidence badge ("Likely AI" / "Likely Human") with traffic-light color
+  3) Tabs for Predict / Advanced controls
+  4) Chunk details accordion with per-chunk probabilities
 """
 
 import os
 import io
 import re
-from typing import Dict, Any
+from typing import Dict, Any, List, Tuple
 
 import numpy as np
 import torch
 import gradio as gr
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSequenceClassification,
-)
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 # -----------------------------
 # Config
 # -----------------------------
-MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")
+MODEL_ID = os.getenv("MODEL_ID", "bert-base-uncased")  # e.g., "username/bert-binclass"
 MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
 STRIDE = int(os.getenv("STRIDE", "128"))
 
 # Device
 device = torch.device("cuda" if torch.cuda.is_available() else
                       "mps" if torch.backends.mps.is_available() else "cpu")
 if device.type == "mps":
     try:
         torch.set_float32_matmul_precision("high")
     except Exception:
         pass
 
 # Load model & tokenizer
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
 model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)
 model.eval()
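The aggregation code further down assumes label index 0 is Human and index 1 is AI. That holds for a matching fine-tune, but the bert-base-uncased fallback only ships generic LABEL_0/LABEL_1 names; a quick sanity check (a sketch, using the config the loaded checkpoint actually carries):

    # Confirm the assumed 0 -> Human, 1 -> AI ordering before trusting the scores.
    print(model.config.id2label)  # e.g. {0: "Human", 1: "AI"} on a suitable fine-tune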
@@ -50,7 +46,6 @@ model.eval()
 # -----------------------------
 # Utilities
 # -----------------------------
-
 TEXT_EXTS = {".txt", ".md", ".rtf", ".html", ".htm"}
 PDF_EXTS = {".pdf"}
@@ -87,7 +82,7 @@ def read_text_from_file(file_obj) -> str:
     except Exception as e:
         return f"[PDF parse error] {e}"
 
-    # Fallback: try …
+    # Fallback: try as text
     data = file_obj.read()
     if isinstance(data, bytes):
         data = data.decode("utf-8", errors="ignore")
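The lenient decode means unrecognized binary input degrades to best-effort text instead of raising. For example:

    b"caf\xc3\xa9 \xff".decode("utf-8", errors="ignore")  # -> 'café ' (the invalid 0xFF byte is dropped)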
@@ -96,8 +91,8 @@ def read_text_from_file(file_obj) -> str:
 
 def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
     """
-    Chunk the document using tokenizer overflow, run …
+    Chunk the document using tokenizer overflow, run classifier on each chunk,
+    aggregate probabilities, and return both doc-level and chunk-level results.
     """
     if not text or not text.strip():
         return {"error": "Empty document."}
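The chunking step itself sits in an unchanged region of the file and is not shown in this diff. For reference, a minimal sketch of tokenizer-overflow chunking with a fast tokenizer (illustrative, not the file's exact code):

    enc = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        stride=stride,                    # tokens shared by consecutive windows
        return_overflowing_tokens=True,   # emit one row per overlapping window
        padding="max_length",             # pad so the rows stack into one tensor
        return_tensors="pt",
    )
    # enc["input_ids"] has shape [num_chunks, max_length]; batch the rows through the model.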
@@ -122,27 +117,32 @@ def chunked_predict(text: str, max_length: int = 512, stride: int = 128, agg: str = "mean") -> Dict[str, Any]:
         out = model(**batch)
         logits_list.append(out.logits)
 
-    logits = torch.cat(logits_list, dim=0)
-    probs …
+    logits = torch.cat(logits_list, dim=0)               # [num_chunks, num_labels]
+    probs = torch.softmax(logits, dim=-1).cpu().numpy()
     num_chunks = int(probs.shape[0])
 
-    all_scores = {id2label.get(i, str(i)): float(doc_probs[i]) for i in range(len(doc_probs))}
+    # Aggregate
+    if agg == "max":
+        doc_probs = probs.max(axis=0)
+    else:
+        doc_probs = probs.mean(axis=0)
+
+    # By convention: 0 -> Human, 1 -> AI
+    prob_human = float(doc_probs[0])
+    prob_ai = float(doc_probs[1])
+
+    # Per-chunk table rows
+    chunk_rows = []
+    for i, p in enumerate(probs):
+        chunk_rows.append([i + 1, float(p[1]), float(p[0])])  # [chunk, AI, Human]
 
     return {
-        "all_scores": all_scores,
+        "ai_prob": prob_ai,
+        "human_prob": prob_human,
         "num_chunks": num_chunks,
+        "chunk_rows": chunk_rows,     # list of [chunk, AI, Human]
+        "max_length": max_length,
         "stride": stride,
-        "model": MODEL_ID,
-        "device": str(device),
     }
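A small worked example of the two aggregation modes, with invented per-chunk probabilities:

    probs = np.array([[0.9, 0.1],   # chunk 1: [Human, AI]
                      [0.2, 0.8]])  # chunk 2
    probs.mean(axis=0)  # -> [0.55, 0.45]: "Likely Human" at the 0.5 threshold
    probs.max(axis=0)   # -> [0.9, 0.8]: per-class maxima, which need not sum to 1

"mean" dilutes a single AI-looking chunk across a long document, while "max" flags the document if any one chunk scores high, trading precision for recall.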
@@ -153,10 +153,9 @@ def predict_from_upload(file, aggregation, max_length, stride):
     # Work around gradio temp file behavior
     if hasattr(file, "name") and isinstance(file.name, str):
         with open(file.name, "rb") as f:
-            …
-        text = read_text_from_file(mem)
+            raw = io.BytesIO(f.read())
+        raw.name = os.path.basename(file.name)
+        text = read_text_from_file(raw)
     else:
         text = read_text_from_file(file)
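Context for the workaround: depending on the Gradio version, gr.File hands the handler either a path-bearing tempfile wrapper or a plain file-like object. Re-reading the bytes into an io.BytesIO and restoring the basename keeps read_text_from_file's extension-based dispatch working. A hypothetical illustration:

    buf = io.BytesIO(open("/tmp/gradio/abc/essay.pdf", "rb").read())  # hypothetical temp path
    buf.name = "essay.pdf"  # the .pdf suffix routes the buffer to the PDF branch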
@@ -164,36 +163,109 @@
 
 
 # -----------------------------
-# …
+# UI Helpers (HTML formatting)
 # -----------------------------
+def probability_bar_html(label: str, prob: float) -> str:
+    """Return an HTML row with label, percent, and a bar."""
+    pct = prob * 100.0
+    return f"""
+    <div class="prob-row"><div class="prob-label"><b>{label}</b></div>
+    <div class="prob-value">{pct:.2f}%</div>
+    <div class="prob-bar">
+      <div class="prob-fill" style="width:{pct:.2f}%"></div>
+    </div>
+    </div>
+    """
+
+def verdict_badge_html(prob_ai: float, threshold: float = 0.5) -> str:
+    label = "Likely AI" if prob_ai >= threshold else "Likely Human"
+    color = "#ef4444" if prob_ai >= threshold else "#10b981"  # red / green
+    return f"<span class='pill' style='background:{color}22;color:{color}'>{label}</span>"
+
+def format_outputs(result: Dict[str, Any], threshold: float = 0.5):
+    """Produce (verdict_html, probs_html, chunk_table_data, details_md)."""
+    if "error" in result:
+        return f"<span style='color:#ef4444'>{result['error']}</span>", "", [], ""
+
+    ai, human = result["ai_prob"], result["human_prob"]
+    verdict_html = verdict_badge_html(ai, threshold=threshold)
+
+    probs_html = ""
+    probs_html += probability_bar_html("AI generated", ai)
+    probs_html += probability_bar_html("Human written", human)
+
+    # Chunk table rows
+    table_data = result["chunk_rows"]
+
+    details_md = (
+        f"**Chunks:** `{result['num_chunks']}`  \n"
+        f"**Tokens per chunk:** `{result['max_length']}`  \n"
+        f"**Stride:** `{result['stride']}`"
+    )
+
+    return verdict_html, probs_html, table_data, details_md
+
+
+# -----------------------------
+# Gradio Interface
+# -----------------------------
+CSS = """
+.pill {padding:6px 12px; border-radius:999px; display:inline-block; margin: 6px 0; font-weight:600;}
+.prob-row {display:flex; align-items:center; gap:10px; margin:6px 0;}
+.prob-label {min-width:140px;}
+.prob-value {min-width:80px; text-align:right; font-variant-numeric: tabular-nums;}
+.prob-bar {flex:1; background:#e5e7eb; height:12px; border-radius:6px; overflow:hidden;}
+.prob-fill {height:12px; background:#6366f1;}
+.small-note {font-size:0.9rem; color:#6b7280;}
+"""
+
+DESCRIPTION = """
+### 🔎 AI vs Human — Document Classifier
+Upload a file to get **document-level probabilities**.
+Long inputs are **chunked** into overlapping windows; chunk predictions are **aggregated**.
+"""
+
+with gr.Blocks(
+    title="AI vs Human Document Classifier",
+    theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate"),
+    css=CSS
+) as demo:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Tabs():
+        with gr.Tab("Predict"):
+            file_in = gr.File(label="Upload a document", file_types=[".txt", ".md", ".rtf", ".html", ".htm", ".pdf"])
+            agg_in = gr.Radio(choices=["mean", "max"], value="mean", label="Aggregation over chunks")
+            btn = gr.Button("Predict", variant="primary")
+            verdict_html = gr.HTML(label="Verdict")
+            probs_html = gr.HTML(label="Probabilities")
+
+            with gr.Accordion("Chunk details", open=False):
+                chunk_table = gr.Dataframe(
+                    headers=["Chunk", "AI generated", "Human written"],
+                    datatype=["number", "number", "number"],
+                    label="Per-chunk probabilities",
+                    wrap=True,
+                    interactive=False,
+                    height=240
+                )
+                details_md = gr.Markdown("", elem_classes=["small-note"])
+
+        with gr.Tab("Advanced"):
+            gr.Markdown("Adjust chunking parameters below.")
+            max_len_in = gr.Slider(128, 1024, value=MAX_LENGTH, step=32, label="Tokens per chunk (max_length)")
+            stride_in = gr.Slider(0, 512, value=STRIDE, step=16, label="Stride / overlap")
+            gr.Markdown("You can also set `MODEL_ID`, `MAX_LENGTH`, and `STRIDE` via Space Variables.")
+
+    def predict_and_prettify(file, aggregation, max_length=MAX_LENGTH, stride=STRIDE):
+        res = predict_from_upload(file, aggregation, max_length, stride)
+        return format_outputs(res)
 
     btn.click(
-        fn=…
-        inputs=[file_in, …
-        outputs=[…
-        api_name="predict",
+        fn=predict_and_prettify,
+        inputs=[file_in, agg_in, max_len_in, stride_in],
+        outputs=[verdict_html, probs_html, chunk_table, details_md],
     )
 
 if __name__ == "__main__":
     demo.launch()
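One behavioral change to flag: the old handler registered the endpoint explicitly as api_name="predict", and the rewrite drops that argument, so recent Gradio releases fall back to naming the endpoint after the bound function. A sketch of a programmatic call under that assumption (the Space id and file name are placeholders):

    from gradio_client import Client, handle_file

    client = Client("username/ai-vs-human")  # hypothetical Space id
    verdict, bars, chunks, details = client.predict(
        handle_file("essay.pdf"), "mean", 512, 128,
        api_name="/predict_and_prettify",  # assumed default name once api_name="predict" is gone
    )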