mahmoudsaber0 committed
Commit e0b229b · verified · 1 Parent(s): 54cad9d

Update app.py

Files changed (1):
  1. app.py +58 -71
app.py CHANGED
@@ -2,41 +2,35 @@ import gradio as gr
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import torch
  import re
- import matplotlib
- matplotlib.use("Agg")
  import matplotlib.pyplot as plt
  from tokenizers.normalizers import Sequence, Replace, Strip
  from tokenizers import Regex
 
- # -------------------------
- # Device setup
- # -------------------------
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
- # -------------------------
- # Model and Tokenizer Setup
- # -------------------------
  model1_path = "modernbert.bin"
  model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
  model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
 
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
- def safe_load_model(base_name, weights_path):
-     model = AutoModelForSequenceClassification.from_pretrained(base_name, num_labels=41)
-     state_dict = torch.hub.load_state_dict_from_url(weights_path, map_location=device) if weights_path.startswith("http") else torch.load(weights_path, map_location=device)
-     model.load_state_dict(state_dict)
-     model.to(device).eval()
-     return model
-
- print("Loading models...")
- model_1 = safe_load_model("answerdotai/ModernBERT-base", model1_path)
- model_2 = safe_load_model("answerdotai/ModernBERT-base", model2_path)
- model_3 = safe_load_model("answerdotai/ModernBERT-base", model3_path)
-
- # -------------------------
- # Label Mapping
- # -------------------------
  label_mapping = {
      0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
      6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
@@ -50,9 +44,7 @@ label_mapping = {
      39: 'text-davinci-002', 40: 'text-davinci-003'
  }
 
- # -------------------------
- # Text Cleaning
- # -------------------------
  def clean_text(text: str) -> str:
      text = re.sub(r'\s{2,}', ' ', text)
      text = re.sub(r'\s+([,.;:?!])', r'\1', text)
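Note (editor's sketch, not part of the commit): the next hunk's header shows that, beyond `clean_text`, the app also installs a `tokenizers` normalizer on `tokenizer.backend_tokenizer.normalizer`; the actual `Replace` rules sit in unchanged lines this diff does not display. A minimal sketch of how such a normalizer is wired up; the two rules below are assumptions mirroring `clean_text`, not the committed ones:

```python
# Hedged sketch: normalizer pipeline applied inside the fast tokenizer.
# The Replace patterns are assumed; the committed ones are not shown in this diff.
from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence, Strip
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
tok.backend_tokenizer.normalizer = Sequence([
    Replace(Regex(r"\s{2,}"), " "),  # collapse repeated whitespace, as clean_text does
    Strip(),                         # trim leading/trailing whitespace
])
print(tok.tokenize("too   many    spaces "))
```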
@@ -65,21 +57,20 @@ tokenizer.backend_tokenizer.normalizer = Sequence([
      Strip()
  ])
 
- # -------------------------
- # Classification Function
- # -------------------------
  def classify_text(text):
      cleaned_text = clean_text(text)
      if not cleaned_text.strip():
-         return "<b style='color:red;'>Please enter some text to analyze.</b>", None
 
      paragraphs = [p.strip() for p in re.split(r'\n{2,}', cleaned_text) if p.strip()]
      chunk_scores = []
      all_probabilities = []
 
      for i, paragraph in enumerate(paragraphs):
          inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
-
          with torch.no_grad():
              logits_1 = model_1(**inputs).logits
              logits_2 = model_2(**inputs).logits
@@ -88,51 +79,50 @@ def classify_text(text):
          softmax_1 = torch.softmax(logits_1, dim=1)
          softmax_2 = torch.softmax(logits_2, dim=1)
          softmax_3 = torch.softmax(logits_3, dim=1)
-         averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
-         probabilities = averaged_probabilities[0]
-         all_probabilities.append(probabilities.cpu())
 
-         human_prob = probabilities[24].item()
-         ai_probs_clone = probabilities.clone()
-         ai_probs_clone[24] = 0
-         ai_total_prob = ai_probs_clone.sum().item()
 
-         total = human_prob + ai_total_prob
          human_pct = (human_prob / total) * 100
-         ai_pct = (ai_total_prob / total) * 100
          ai_model = label_mapping[torch.argmax(ai_probs_clone).item()]
 
          chunk_scores.append({
-             "paragraph": paragraph[:150] + ("..." if len(paragraph) > 150 else ""),
              "human": human_pct,
              "ai": ai_pct,
-             "model": ai_model
          })
 
-     # --- Overall ---
      avg_human = sum(c["human"] for c in chunk_scores) / len(chunk_scores)
      avg_ai = sum(c["ai"] for c in chunk_scores) / len(chunk_scores)
      if avg_human > avg_ai:
          result_message = f"**Overall Result:** <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
      else:
          top_model = max(chunk_scores, key=lambda c: c['ai'])['model']
          result_message = f"**Overall Result:** <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
 
-     # --- Paragraph Breakdown ---
-     paragraph_html = "<h3>Paragraph Analysis:</h3>"
-     for idx, c in enumerate(chunk_scores, 1):
-         color = "#4CAF50" if c['human'] > c['ai'] else "#FF5733"
-         paragraph_html += f"""
-         <div style='margin-bottom:10px; border-left:4px solid {color}; padding-left:10px;'>
-         <b>Paragraph {idx}</b>: {c['human']:.2f}% Human | {c['ai']:.2f}% AI → <i>{c['model']}</i><br>
-         <small>{c['paragraph']}</small></div>
-         """
-
-     # --- Plot ---
      mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
-     top_5_probs, top_5_indices = torch.topk(mean_probs, 5)
      top_5_probs = top_5_probs.cpu().numpy()
-     top_5_labels = [label_mapping[i.item()] for i in top_5_indices]
 
      fig, ax = plt.subplots(figsize=(10, 5))
      bars = ax.barh(top_5_labels, top_5_probs, color='#4CAF50')
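Note (editor's sketch, not part of the commit): both versions of `classify_text` score a paragraph the same way; the fragments above show the three models' softmax outputs being averaged, `probabilities[24]` taken as the human probability, and the remaining 40 classes summed into one AI score before renormalizing. A self-contained sketch of that scheme, with dummy logits standing in for the three models:

```python
import torch

def ensemble_scores(logits_list, human_idx=24):
    """Average per-model softmax distributions, then split human vs. AI mass."""
    probs = torch.stack([torch.softmax(l, dim=1) for l in logits_list]).mean(dim=0)[0]
    human = probs[human_idx].item()
    ai = probs.clone()
    ai[human_idx] = 0          # zero out the human class
    ai_total = ai.sum().item()
    total = human + ai_total   # ~1.0 up to numerical error
    return 100 * human / total, 100 * ai_total / total, int(torch.argmax(ai))

# Dummy logits: three "models", batch of 1, 41 classes as in the app.
human_pct, ai_pct, top_ai = ensemble_scores([torch.randn(1, 41) for _ in range(3)])
```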
@@ -144,30 +134,31 @@ def classify_text(text):
          ax.text(width + 0.005, bar.get_y() + bar.get_height() / 2, f'{width:.2%}', va='center')
      plt.tight_layout()
 
-     return result_message + "<br><br>" + paragraph_html, fig
 
 
- # -------------------------
- # UI Setup
- # -------------------------
  title = "AI Text Detector"
  description = """
- This tool uses <b>ModernBERT</b> to detect AI-generated text.<br>
  Each paragraph is analyzed separately to show which parts are likely AI-generated.
  """
  bottom_text = "**Developed by SzegedAI – Extended by Saber**"
 
  AI_texts = [
-     "Artificial intelligence (AI) is reshaping industries by automating tasks, enhancing decision-making, and driving innovation. From predictive analytics in finance to autonomous vehicles in transportation, AI technologies are becoming integral to daily operations."
  ]
 
  Human_texts = [
-     "Mathematics has always been a cornerstone of scientific discovery. It provides a precise language for describing natural phenomena, from the orbit of planets to the behavior of subatomic particles."
  ]
 
  iface = gr.Blocks(css="""
      @import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');
-     body { font-family: 'Roboto Mono', sans-serif !important; }
      .highlight-human { color: #4CAF50; font-weight: bold; }
      .highlight-ai { color: #FF5733; font-weight: bold; }
  """)
@@ -175,18 +166,14 @@ iface = gr.Blocks(css="""
  with iface:
      gr.Markdown(f"# {title}")
      gr.Markdown(description)
-     text_input = gr.Textbox(label="", placeholder="Paste your article here...", lines=10)
-     analyze_btn = gr.Button("🔍 Analyze", variant="primary")
-     result_output = gr.HTML(label="Result")
      plot_output = gr.Plot(label="Model Probability Distribution")
-
-     analyze_btn.click(classify_text, inputs=text_input, outputs=[result_output, plot_output])
-
      with gr.Tab("AI Examples"):
          gr.Examples(AI_texts, inputs=text_input)
      with gr.Tab("Human Examples"):
          gr.Examples(Human_texts, inputs=text_input)
-
      gr.Markdown(bottom_text)
 
  iface.launch(share=True)
 
app.py AFTER CHANGE (added lines marked +):

  from transformers import AutoTokenizer, AutoModelForSequenceClassification
  import torch
  import re
  import matplotlib.pyplot as plt
  from tokenizers.normalizers import Sequence, Replace, Strip
  from tokenizers import Regex
 
+ # ---- Device setup ----
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+ # ---- Model and Tokenizer Setup ----
  model1_path = "modernbert.bin"
  model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
  model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
 
  tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
 
+ # Load models
+ model_1 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
+ model_1.load_state_dict(torch.load(model1_path, map_location=device))
+ model_1.to(device).eval()
+
+ model_2 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
+ model_2.load_state_dict(torch.hub.load_state_dict_from_url(model2_path, map_location=device))
+ model_2.to(device).eval()
+
+ model_3 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
+ model_3.load_state_dict(torch.hub.load_state_dict_from_url(model3_path, map_location=device))
+ model_3.to(device).eval()
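Note: `torch.hub.load_state_dict_from_url` downloads a remote checkpoint on first use and caches it (by default under `~/.cache/torch/hub/checkpoints`), so models 2 and 3 are fetched only once per environment. Model 1 instead expects `modernbert.bin` to already exist next to `app.py`.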
 
+ # ---- Label Mapping ----
  label_mapping = {
      0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
      6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
 
      39: 'text-davinci-002', 40: 'text-davinci-003'
  }
 
+ # ---- Text Cleaning ----
  def clean_text(text: str) -> str:
      text = re.sub(r'\s{2,}', ' ', text)
      text = re.sub(r'\s+([,.;:?!])', r'\1', text)
 
      Strip()
  ])
 
+
+ # ---- Classification Function ----
  def classify_text(text):
      cleaned_text = clean_text(text)
      if not cleaned_text.strip():
+         return "**Error:** Please enter some text to analyze.", None
 
+     # Split into paragraphs
      paragraphs = [p.strip() for p in re.split(r'\n{2,}', cleaned_text) if p.strip()]
      chunk_scores = []
      all_probabilities = []
 
      for i, paragraph in enumerate(paragraphs):
          inputs = tokenizer(paragraph, return_tensors="pt", truncation=True, padding=True).to(device)
          with torch.no_grad():
              logits_1 = model_1(**inputs).logits
              logits_2 = model_2(**inputs).logits
 
          softmax_1 = torch.softmax(logits_1, dim=1)
          softmax_2 = torch.softmax(logits_2, dim=1)
          softmax_3 = torch.softmax(logits_3, dim=1)
 
+         avg_probs = (softmax_1 + softmax_2 + softmax_3) / 3
+         probs = avg_probs[0]
+         all_probabilities.append(probs.cpu())
 
+         human_prob = probs[24].item()
+         ai_probs_clone = probs.clone()
+         ai_probs_clone[24] = 0
+         ai_total = ai_probs_clone.sum().item()
+         total = human_prob + ai_total
          human_pct = (human_prob / total) * 100
+         ai_pct = (ai_total / total) * 100
          ai_model = label_mapping[torch.argmax(ai_probs_clone).item()]
 
          chunk_scores.append({
              "human": human_pct,
              "ai": ai_pct,
+             "model": ai_model,
+             "text": paragraph[:200].replace('\n', ' ') + ("..." if len(paragraph) > 200 else "")
          })
 
+     # ---- Overall Averages ----
      avg_human = sum(c["human"] for c in chunk_scores) / len(chunk_scores)
      avg_ai = sum(c["ai"] for c in chunk_scores) / len(chunk_scores)
+
      if avg_human > avg_ai:
          result_message = f"**Overall Result:** <span class='highlight-human'>{avg_human:.2f}% Human-written</span>"
      else:
          top_model = max(chunk_scores, key=lambda c: c['ai'])['model']
          result_message = f"**Overall Result:** <span class='highlight-ai'>{avg_ai:.2f}% AI-generated (likely {top_model})</span>"
 
+     # ---- Paragraph Analysis (Markdown Clean) ----
+     paragraph_text = "\n\n**Paragraph Analysis:**\n"
+     for i, c in enumerate(chunk_scores, 1):
+         paragraph_text += (
+             f"**Paragraph {i}:** {c['human']:.2f}% Human | {c['ai']:.2f}% AI → *{c['model']}*\n"
+             f"{c['text']}\n\n"
+         )
+
+     # ---- Top 5 Models Plot ----
      mean_probs = torch.mean(torch.stack(all_probabilities), dim=0)
+     top_5_probs, top_5_idx = torch.topk(mean_probs, 5)
      top_5_probs = top_5_probs.cpu().numpy()
+     top_5_labels = [label_mapping[i.item()] for i in top_5_idx]
 
      fig, ax = plt.subplots(figsize=(10, 5))
      bars = ax.barh(top_5_labels, top_5_probs, color='#4CAF50')
 
          ax.text(width + 0.005, bar.get_y() + bar.get_height() / 2, f'{width:.2%}', va='center')
      plt.tight_layout()
 
+     return result_message + "\n\n" + paragraph_text, fig
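Note (editor's sketch, not part of the commit): a quick smoke test for the reworked `classify_text`, run in the app's context (for example a Python REPL on the Space); the sample text is arbitrary:

```python
# classify_text returns (result string, matplotlib Figure), or (error string, None)
# when the input is empty.
sample = "First paragraph of a draft.\n\nSecond paragraph, written separately."
msg, fig = classify_text(sample)
print(msg)                          # overall verdict plus per-paragraph breakdown
if fig is not None:
    fig.savefig("top5_models.png")  # the top-5 model probability chart
```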
 
 
+ # ---- UI Setup ----
  title = "AI Text Detector"
  description = """
+ This tool uses **ModernBERT** to detect AI-generated text.
  Each paragraph is analyzed separately to show which parts are likely AI-generated.
  """
  bottom_text = "**Developed by SzegedAI – Extended by Saber**"
 
  AI_texts = [
+     "Artificial intelligence (AI) is reshaping industries by automating tasks, enhancing decision-making, and driving innovation. From predictive analytics in finance to autonomous vehicles in transportation, AI technologies are becoming integral to daily operations. The future of AI lies not only in technological advancement but also in ensuring ethical use, transparency, and accountability."
  ]
 
  Human_texts = [
+     "Mathematics has always been a cornerstone of scientific discovery. It provides a precise language for describing natural phenomena, from the orbit of planets to the behavior of subatomic particles. The beauty of mathematics lies in its universality—its principles hold true regardless of context or culture."
  ]
 
  iface = gr.Blocks(css="""
      @import url('https://fonts.googleapis.com/css2?family=Roboto+Mono:wght@400;700&display=swap');
+     #text_input_box { border-radius: 10px; border: 2px solid #4CAF50; font-size: 18px; padding: 15px; margin-bottom: 20px; width: 60%; margin: auto; }
+     #result_output_box { border-radius: 10px; border: 2px solid #4CAF50; font-size: 16px; padding: 15px; margin-top: 20px; width: 80%; margin: auto; }
+     body { font-family: 'Roboto Mono', sans-serif !important; padding: 20px; }
+     .gradio-container { border: 1px solid #4CAF50; border-radius: 15px; padding: 30px; box-shadow: 0px 0px 10px rgba(0,255,0,0.4); max-width: 900px; margin: auto; }
      .highlight-human { color: #4CAF50; font-weight: bold; }
      .highlight-ai { color: #FF5733; font-weight: bold; }
  """)
 
  with iface:
      gr.Markdown(f"# {title}")
      gr.Markdown(description)
+     text_input = gr.Textbox(label="", placeholder="Paste your article here...", elem_id="text_input_box", lines=10)
+     result_output = gr.HTML("", elem_id="result_output_box")
      plot_output = gr.Plot(label="Model Probability Distribution")
+     text_input.change(classify_text, inputs=text_input, outputs=[result_output, plot_output])
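Note: `text_input.change` re-runs the three-model ensemble on every edit of the textbox, which can be costly on CPU hardware; the previous button-triggered flow ran it once per click. Also, `result_output` is a `gr.HTML` component, so the `**...**` Markdown markers in the result string render literally; a `gr.Markdown` output component would render them as bold.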
 
 
      with gr.Tab("AI Examples"):
          gr.Examples(AI_texts, inputs=text_input)
      with gr.Tab("Human Examples"):
          gr.Examples(Human_texts, inputs=text_input)
 
      gr.Markdown(bottom_text)
 
  iface.launch(share=True)
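Note: when run outside Spaces, `share=True` additionally requests a temporary public `gradio.live` link; on a hosted Hugging Face Space the app is already served at a public URL, so the flag has no practical effect there.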