yaekobB commited on
Commit
65380c5
·
0 Parent(s):

Initial push to HF Spaces

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py — Classify + Explain (Captum IG) — polished UX

# (Optional) silence common warnings on Windows/HF.
# Must be set BEFORE the transformers/tokenizers imports below to take effect.
import os
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoConfig
from safetensors.torch import load_file
from captum.attr import LayerIntegratedGradients # explainability
17
+
18
# ----------------------------
# Paths / labels / config
# ----------------------------
ARTI_DIR = "artifacts"                                  # root of exported training artifacts
BEST_DIR = os.path.join(ARTI_DIR, "best")               # best checkpoint: weights + tokenizer
THRESH_FP = os.path.join(ARTI_DIR, "thresholds.json")   # per-label decision thresholds

# Label order must match the training head's output column order.
LABELS = ["toxic","severe_toxic","obscene","threat","insult","identity_hate"]
NUM_LABELS = len(LABELS)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LEN = 256  # tokenizer truncation length used for both inference and IG
BASE_MODEL = "distilbert-base-uncased" # same backbone as training
30
+
31
+ # ----------------------------
32
+ # Model definition (same logic)
33
+ # ----------------------------
34
class ToxicMultiLabel(nn.Module):
    """
    Multi-label toxicity classifier: DistilBERT encoder topped with a single
    linear head that emits raw logits, one per label. Callers apply sigmoid
    at inference time to turn logits into independent per-label probabilities.
    """

    def __init__(self, base_model_name: str, num_labels: int, head_dropout: float = 0.30):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name)
        self.backbone = AutoModel.from_pretrained(base_model_name, config=config)
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(self.backbone.config.hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None):
        """Return (batch, num_labels) logits for the given token batch."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # First token of the last hidden state serves as the sequence summary
        # ([CLS]-style pooling).
        pooled = encoded.last_hidden_state[:, 0]
        return self.classifier(self.dropout(pooled))
52
+
53
+ # ----------------------------
54
+ # Load artifacts (tokenizer, model, thresholds)
55
+ # ----------------------------
56
def load_artifacts():
    """
    Load the tokenizer, model weights, and per-label thresholds from artifacts/.

    Returns:
        (model, tokenizer, thresholds) — model moved to DEVICE and set to eval
        mode; thresholds is a dict mapping each label in LABELS to a float cutoff.

    Raises:
        FileNotFoundError: if neither model.safetensors nor pytorch_model.bin
        exists under artifacts/best/.
    """
    # Tokenizer: prefer the checkpoint's saved tokenizer when present so the
    # vocab matches training exactly; otherwise fall back to the hub backbone.
    tok_src = BEST_DIR if os.path.isfile(os.path.join(BEST_DIR, "tokenizer.json")) else BASE_MODEL
    tok = AutoTokenizer.from_pretrained(tok_src, use_fast=True)

    # Model weights — safetensors preferred (no pickle, safe to load).
    model = ToxicMultiLabel(BASE_MODEL, NUM_LABELS)
    safep = os.path.join(BEST_DIR, "model.safetensors")
    binp = os.path.join(BEST_DIR, "pytorch_model.bin")

    if os.path.isfile(safep):
        state = load_file(safep)
    elif os.path.isfile(binp):
        # weights_only=True: a .bin checkpoint is a pickle and can execute
        # arbitrary code on load; we only need plain tensors, so restrict it.
        state = torch.load(binp, map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError("No weights found (model.safetensors / pytorch_model.bin) in artifacts/best/")

    # Strip training-only keys (loss buffers etc.) if any slipped into the dict.
    for k in list(state.keys()):
        if k.startswith("pos_weight") or k.startswith("loss_fn"):
            state.pop(k, None)

    model.load_state_dict(state, strict=True)
    model.to(DEVICE).eval()

    # Thresholds: load if present, else seed 0.5 defaults and persist them so
    # later threshold tuning has a file to edit.
    if os.path.isfile(THRESH_FP):
        with open(THRESH_FP) as f:
            thresholds = json.load(f)
    else:
        thresholds = {lab: 0.5 for lab in LABELS}
        os.makedirs(ARTI_DIR, exist_ok=True)
        with open(THRESH_FP, "w") as f:
            json.dump(thresholds, f, indent=2)

    return model, tok, thresholds
92
+
93
# Load everything once at import time; these module-level singletons are
# shared by all Gradio callbacks below.
MODEL, TOK, THRESH = load_artifacts()
94
+
95
+ # =========================
96
+ # Inference (Classify tab)
97
+ # =========================
98
@torch.no_grad()
def classify_comment(text: str):
    """
    Score one comment against every label.

    Returns:
        (DataFrame with columns label/probability/threshold/margin/decision,
         comma-separated names of labels predicted positive, or "(none)").
    """
    cleaned = (text or "").strip()
    if not cleaned:
        empty = pd.DataFrame(columns=["label","probability","threshold","margin","decision"])
        return empty, "(none)"

    batch = TOK(cleaned, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt")
    batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
    raw = MODEL(**batch).squeeze(0).detach().cpu().numpy()
    probs = 1.0 / (1.0 + np.exp(-raw))  # element-wise sigmoid: labels are independent

    records = []
    for label, prob in zip(LABELS, probs):
        prob = float(prob)
        cut = float(THRESH.get(label, 0.5))
        records.append({
            "label": label,
            "probability": round(prob, 4),
            "threshold": round(cut, 4),
            "margin": round(prob - cut, 4),
            "decision": "POS" if prob >= cut else "NEG",
        })

    # POS rows first, then by descending margin and probability.
    table = (
        pd.DataFrame(records)
        .sort_values(["decision", "margin", "probability"], ascending=[False, False, False])
        .reset_index(drop=True)
    )

    hits = [rec["label"] for rec in records if rec["probability"] >= rec["threshold"]]
    return table, ", ".join(hits) if hits else "(none)"
131
+
132
+ # =========================
133
+ # Explainability (IG tab)
134
+ # =========================
135
# Layer IG on embedding layer: attribute the chosen logit back to the
# backbone's word-embedding outputs, one vector per input token.
EMB_LAYER = MODEL.backbone.embeddings.word_embeddings

# Captum forward: single logit for chosen label
def _forward_for_label(input_ids, attention_mask, class_index: int):
    """Forward pass returning only the logit column for `class_index` (Captum target)."""
    logits = MODEL(input_ids=input_ids, attention_mask=attention_mask) # (B, L)
    return logits[:, class_index]

# Built once at import; reused for every explanation request.
LIG = LayerIntegratedGradients(_forward_for_label, EMB_LAYER)
144
+
145
def _tokenize_with_offsets(text: str):
    """Tokenize `text` and keep character offsets so tokens map back to spans."""
    options = dict(
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
        return_tensors="pt",
        return_offsets_mapping=True,
    )
    return TOK(text, **options)
148
+
149
+ def _merge_wordpieces(tokens, offsets, scores):
150
+ """Merge WordPiece tokens (##subwords) into words; sum scores."""
151
+ words = []
152
+ for tok_piece, (start, end), sc in zip(tokens, offsets, scores):
153
+ # skip special tokens with (0,0) offsets
154
+ if (start, end) == (0, 0) and tok_piece.startswith("[") and tok_piece.endswith("]"):
155
+ continue
156
+ if tok_piece.startswith("##") and words:
157
+ words[-1]["text"] += tok_piece[2:]
158
+ words[-1]["end"] = end
159
+ words[-1]["score"] += float(sc)
160
+ else:
161
+ words.append({"text": tok_piece, "start": start, "end": end, "score": float(sc)})
162
+ return words
163
+
164
@torch.no_grad()
def _predict_probs(text: str):
    """Return per-label sigmoid probabilities for one comment as a (L,) array."""
    encoded = TOK(text, truncation=True, padding=True, max_length=MAX_LEN, return_tensors="pt")
    on_device = {key: value.to(DEVICE) for key, value in encoded.items()}
    raw = MODEL(**on_device).squeeze(0).detach().cpu().numpy()
    return 1.0 / (1.0 + np.exp(-raw))
170
+
171
def explain_comment(text: str, target_label: str, steps: int = 30):
    """
    Explain one comment for one label using Layer Integrated Gradients.

    Args:
        text: raw comment text (stripped before tokenization).
        target_label: one of LABELS; selects which logit is attributed.
        steps: IG interpolation steps (clamped to >= 4 below).

    Returns (HTML with colored spans, selected label prob as string).
    Red = supports the label; Blue = opposes the label.
    """
    import html as ihtml  # local import: only this tab needs HTML escaping

    text = (text or "").strip()
    if not text:
        return "<i>Provide a comment to explain.</i>", "0.000"

    idx = LABELS.index(target_label)  # logit column for the chosen label
    enc = _tokenize_with_offsets(text)
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)
    offsets = enc["offset_mapping"][0].tolist()  # per-token character spans
    tokens = TOK.convert_ids_to_tokens(enc["input_ids"][0])

    # PAD baseline: the "absence of input" reference IG integrates from
    ref_ids = torch.full_like(input_ids, TOK.pad_token_id)

    # Be robust to Captum return signature (a tuple when delta is requested)
    res = LIG.attribute(
        inputs=input_ids,
        baselines=ref_ids,
        additional_forward_args=(attention_mask, idx),
        n_steps=int(max(4, steps)),
        return_convergence_delta=True,
    )
    attributions = res[0] if isinstance(res, tuple) else res
    # Sum over the embedding dimension -> one scalar score per token
    token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()

    # NOTE(review): if every token is special, `pieces` is empty and np.max
    # below would raise on an empty array — confirm inputs always yield tokens.
    pieces = _merge_wordpieces(tokens, offsets, token_attr)
    # Normalize by max |score| so colors use the full alpha range
    arr = np.array([p["score"] for p in pieces], dtype=np.float32)
    denom = float(np.max(np.abs(arr))) if np.max(np.abs(arr)) > 1e-8 else 1.0
    for p in pieces:
        p["score_norm"] = p["score"] / denom

    def _color_for(s: float) -> str:
        # Alpha tracks |score|; the 0.06 floor keeps weak tokens visible
        alpha = min(1.0, max(0.06, abs(s)))
        return f"rgba(255,0,0,{alpha:.25f})" if s >= 0 else f"rgba(0,0,255,{alpha:.25f})"

    # Rebuild the original text, wrapping each scored word in a colored span;
    # text between words is emitted unstyled (escaped) to preserve spacing.
    out, last = "", 0
    for p in pieces:
        out += ihtml.escape(text[last:p["start"]])
        out += (
            f'<span title="score={p["score_norm"]:+.3f}" '
            f'style="background:{_color_for(p["score_norm"])}; padding:1px 2px; border-radius:3px;">'
            f'{ihtml.escape(text[p["start"]:p["end"]])}</span>'
        )
        last = p["end"]
    out += ihtml.escape(text[last:])

    # Separate forward pass for the headline probability (no gradients needed)
    probs = _predict_probs(text)
    prob = float(probs[idx])
    header = (
        f"<h4 style='margin:6px 0;'>Label: <code>{target_label}</code> "
        f"| Prob: {prob:.3f}</h4>"
        "<div style='margin:4px 0 8px 0;'>Legend: "
        "<span style='background:rgba(255,0,0,.25);padding:0 6px;'>supports</span> &nbsp; "
        "<span style='background:rgba(0,0,255,.25);padding:0 6px;'>opposes</span></div>"
    )
    html_block = header + f"<div style='font-family:ui-sans-serif,system-ui;line-height:1.7;font-size:15px;'>{out}</div>"
    return html_block, f"{prob:.3f}"
235
+
236
+ # =========================
237
+ # Gradio UI (shared textbox)
238
+ # =========================
239
# Example comments shared by both tabs (index 1 is the textbox default).
EXAMPLES = [
    "You are a complete idiot. Get banned already.",
    "I will kill you tomorrow. Watch your back.",
    "Thanks for your help—really appreciate your time!",
    "Shut up, this is the dumbest edit ever.",
    "Go away, you people don't belong here.",
]

with gr.Blocks(
    title="🧠 Toxic Comment Classifier & Explainer",
    theme=gr.themes.Soft(primary_hue="blue")
) as demo:
    gr.Markdown(
        f"""
# 🧠 Toxic Comment Classifier & Explainer
A DistilBERT-based **multi-label** model for detecting toxicity in online comments
with **Integrated Gradients** explanations (Captum).

**Device:** `{DEVICE}` &nbsp;&nbsp;•&nbsp;&nbsp; **Max length:** {MAX_LEN}
"""
    )

    # Shared textbox: one input component feeds BOTH tabs, so users don't
    # retype the comment when switching between Classify and Explain.
    txt = gr.Textbox(
        label="Enter a comment",
        lines=4,
        value=EXAMPLES[1],
        placeholder="Type or paste a comment here…"
    )

    with gr.Tab("🔍 Classify"):
        btn = gr.Button("Classify", variant="primary")
        out_tbl = gr.Dataframe(
            headers=["label","probability","threshold","margin","decision"],
            label="Per-label predictions",
            interactive=False, wrap=True
        )
        out_pos = gr.Textbox(label="Predicted positive labels", interactive=False)
        btn.click(classify_comment, inputs=txt, outputs=[out_tbl, out_pos])
        gr.Examples(EXAMPLES, inputs=txt, label="Examples")

    with gr.Tab("🧩 Explain"):
        lab_dd = gr.Dropdown(choices=LABELS, value="toxic", label="Target label")
        steps_slider = gr.Slider(6, 80, value=30, step=2,
                                 label="IG steps (higher = smoother, slower)")
        explain_btn = gr.Button("Generate explanation", variant="primary")
        prob_box = gr.Textbox(label="Selected label probability", interactive=False)
        html_vis = gr.HTML(label="Attribution heatmap")
        explain_btn.click(
            fn=explain_comment,
            inputs=[txt, lab_dd, steps_slider], # shared text
            outputs=[html_vis, prob_box]
        )
        gr.Examples(EXAMPLES, inputs=txt, label="Examples for Explain")

    with gr.Accordion("ℹ️ About & Responsible Use", open=False):
        gr.Markdown(
            """
**Labels:** `toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, `identity_hate`
This demo is for **research/education**. Do not use as-is for moderation without
human oversight, bias assessment, and policy alignment. Explanations
(IG attributions) are **heuristics**, not proof of model causality.
"""
        )
303
+
304
if __name__ == "__main__":
    # Local entry point. For HF Spaces, you can use: demo.launch(share=False)
    demo.launch(share=False)
artifacts/best/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e2ffa11b258be60ddd69d58763ed4c2c22a6311cf696ed433d50e4d148f63ca
3
+ size 265482136
artifacts/best/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
artifacts/best/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
artifacts/best/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "DistilBertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
artifacts/best/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcbb31fc410b121c1b246b8bea65f9898156d180516614a0d209692f89f3ad4e
3
+ size 5304
artifacts/best/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
artifacts/thresholds.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "toxic": 0.5,
3
+ "severe_toxic": 0.5,
4
+ "obscene": 0.5,
5
+ "threat": 0.5,
6
+ "insult": 0.5,
7
+ "identity_hate": 0.5
8
+ }
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Runtime dependencies for this Space.
+ # (Removed stray markdown scaffolding — lines such as "---" and a ```txt
+ # code fence are not valid requirement specifiers and break `pip install -r`.)
6
+ torch
7
+ transformers
8
+ datasets
9
+ pandas
10
+ numpy
11
+ scikit-learn
12
+ seaborn
13
+ matplotlib
14
+ captum
15
+ gradio
16
+ safetensors
17
+ ftfy
18
+ iterative-stratification