ai-assist-sh commited on
Commit
f9467b7
·
verified ·
1 Parent(s): 755ffe2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -63
main.py CHANGED
@@ -1,8 +1,9 @@
1
- import os, re, time, json, tempfile
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
5
 
 
6
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
7
 
8
  URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
@@ -70,19 +71,20 @@ def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, t
70
  md.append("```txt\n" + cls_preview + "\n```")
71
  return "\n".join(md)
72
 
73
- def analyze(text: str, forensic: bool, forensics_json: str):
74
  """
75
- Returns:
76
- - Markdown body
77
- - Updated forensics_json (string)
 
78
  """
79
  text = (text or "").strip()
80
  if not text:
81
- return "Paste an email body or a URL.", ""
82
 
83
  urls = [text] if (text.lower().startswith(("http://","https://","www.")) and " " not in text) else _extract_urls(text)
84
  if not urls:
85
- return "No URLs detected in the text.", ""
86
 
87
  tok, mdl = _load_model()
88
  id2label_raw = getattr(mdl.config, "id2label", None) or {}
@@ -119,9 +121,9 @@ def analyze(text: str, forensic: bool, forensics_json: str):
119
  out = mdl(**enc, output_hidden_states=True)
120
  elapsed = time.time() - t0
121
 
122
- logits = out.logits.squeeze(0)
123
- probs = _softmax(logits)
124
- hidden_states = out.hidden_states
125
  cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
126
 
127
  per_class = [
@@ -141,7 +143,7 @@ def analyze(text: str, forensic: bool, forensics_json: str):
141
  "truncated": truncated,
142
  "logits": [float(x) for x in logits.cpu().tolist()],
143
  "probs": [float(p) for p in probs],
144
- "scores_sorted": per_class_sorted,
145
  "cls_vector": cls_vec,
146
  "cls_dim": len(cls_vec),
147
  "elapsed_sec": elapsed,
@@ -162,61 +164,30 @@ def analyze(text: str, forensic: bool, forensics_json: str):
162
 
163
  verdict = "🔴 **UNSAFE (links flagged)**" if unsafe else "🟢 **SAFE (all links benign)**"
164
  body = verdict + "\n\n" + _markdown_table(rows)
 
165
  if forensic and forensic_blocks:
166
  body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
167
 
168
- # Return JSON string (not dict) to avoid schema bug
169
- return body, json.dumps(export_data, ensure_ascii=False)
170
-
171
- def export_forensics(forensics_json: str):
172
- """Write the JSON string to a file and return the path."""
173
- if not forensics_json:
174
- return None
175
- try:
176
- data = json.loads(forensics_json)
177
- if not isinstance(data, dict) or not data.get("items"):
178
- return None
179
- except Exception:
180
- return None
181
- fd, path = tempfile.mkstemp(prefix="forensics_", suffix=".json")
182
- with os.fdopen(fd, "w", encoding="utf-8") as f:
183
- f.write(forensics_json)
184
- return path
185
-
186
- with gr.Blocks() as demo:
187
- gr.Markdown("# 🛡️ PhishingMail — Forensics (Tokens, Logits, CLS)")
188
- gr.Markdown(
189
- "Paste an **email body** or a **URL**. We extract links and classify each with a compact malicious-URL model. "
190
- "Enable **Forensic mode** to show tokens, logits, and the **[CLS] embedding**. "
191
- "Use **Export** to download full forensics as JSON."
192
- )
193
-
194
- with gr.Row():
195
- inp = gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…")
196
- forensic_chk = gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=False)
197
-
198
- # Hidden storage for forensics JSON (string)
199
- forensics_json_store = gr.Textbox(value="", visible=False)
200
-
201
- with gr.Row():
202
- btn_analyze = gr.Button("Analyze", variant="primary")
203
- btn_export = gr.Button("Export forensics (JSON)")
204
-
205
- out_md = gr.Markdown(label="Results")
206
- out_file = gr.File(label="Download forensics JSON", interactive=False)
207
-
208
- btn_analyze.click(
209
- analyze,
210
- inputs=[inp, forensic_chk, forensics_json_store],
211
- outputs=[out_md, forensics_json_store],
212
- show_progress=True,
213
- )
214
- btn_export.click(
215
- export_forensics,
216
- inputs=[forensics_json_store],
217
- outputs=[out_file],
218
- )
219
 
220
  if __name__ == "__main__":
221
- # Extra-safe config for HF Spaces
222
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
+ import os, re, time, json
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
5
 
6
+ # Be quiet + CPU friendly
7
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
8
 
9
  URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
 
71
  md.append("```txt\n" + cls_preview + "\n```")
72
  return "\n".join(md)
73
 
74
+ def analyze(text: str, forensic: bool, show_json: bool):
75
  """
76
+ Returns a single Markdown block:
77
+ - verdict + compact table
78
+ - optional forensic blocks (tokens, logits, CLS)
79
+ - optional raw JSON (copy/paste)
80
  """
81
  text = (text or "").strip()
82
  if not text:
83
+ return "Paste an email body or a URL."
84
 
85
  urls = [text] if (text.lower().startswith(("http://","https://","www.")) and " " not in text) else _extract_urls(text)
86
  if not urls:
87
+ return "No URLs detected in the text."
88
 
89
  tok, mdl = _load_model()
90
  id2label_raw = getattr(mdl.config, "id2label", None) or {}
 
121
  out = mdl(**enc, output_hidden_states=True)
122
  elapsed = time.time() - t0
123
 
124
+ logits = out.logits.squeeze(0) # (num_labels,)
125
+ probs = _softmax(logits) # list[float]
126
+ hidden_states = out.hidden_states # tuple of layers
127
  cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
128
 
129
  per_class = [
 
143
  "truncated": truncated,
144
  "logits": [float(x) for x in logits.cpu().tolist()],
145
  "probs": [float(p) for p in probs],
146
+ "scores_sorted": per_class_sorted, # label+prob+logit
147
  "cls_vector": cls_vec,
148
  "cls_dim": len(cls_vec),
149
  "elapsed_sec": elapsed,
 
164
 
165
  verdict = "🔴 **UNSAFE (links flagged)**" if unsafe else "🟢 **SAFE (all links benign)**"
166
  body = verdict + "\n\n" + _markdown_table(rows)
167
+
168
  if forensic and forensic_blocks:
169
  body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
170
 
171
+ if show_json:
172
+ # raw JSON for copy-paste (no File component needed)
173
+ pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
174
+ body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
175
+ body += "```json\n" + pretty + "\n```"
176
+
177
+ return body
178
+
179
+ demo = gr.Interface(
180
+ fn=analyze,
181
+ inputs=[
182
+ gr.Textbox(lines=6, label="Email or URL", placeholder="Paste a URL or a full email…"),
183
+ gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
184
+ gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
185
+ ],
186
+ outputs=gr.Markdown(label="Results"),
187
+ title="🛡️ PhishingMail — Forensics (HF Free CPU)",
188
+ description="Extract links, classify with a tiny URL model, and (optionally) view tokens, logits, and [CLS] embedding.",
189
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  if __name__ == "__main__":
192
+ # Safe defaults for HF Spaces (no share=True needed)
193
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)