zeyadusf committed
Commit 2a3ba09
1 Parent(s): 42b6795

Update app.py

Files changed (1)
  1. app.py +32 -18
app.py CHANGED
@@ -99,7 +99,7 @@ class TextDetectionApp:
         with torch.no_grad():
             detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
         # Return result based on the score threshold
-        return "Generated Text" if detection_score > 0.5 else "Human-Written"
+        return detection_score
 
     def classify_text(self, text, model_choice):
         """
@@ -110,7 +110,7 @@ class TextDetectionApp:
             model_choice (str): The model to use ('DeBERTa', 'RoBERTa', 'BERT', 'DistilBERT', or 'Feedforward').
 
         Returns:
-            str: The classification result.
+            str: The classification result including prediction scores.
         """
         if model_choice == 'DeBERTa':
             # Tokenize input
@@ -121,9 +121,12 @@ class TextDetectionApp:
 
             # Get classification results
             logits = outputs.logits
-            predicted_class_id = logits.argmax().item()
-            label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
-            return f"{label} )"
+            scores = torch.softmax(logits, dim=1)[0]
+            generated_score = scores[1].item()
+            human_written_score = scores[0].item()
+            label = "Generated Text" if generated_score > 0.5 else "Human-Written"
+            return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"
+
         elif model_choice == 'RoBERTa':
             # Tokenize input
             inputs = self.roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -133,9 +136,12 @@ class TextDetectionApp:
 
             # Get classification results
             logits = outputs.logits
-            predicted_class_id = logits.argmax().item()
-            label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
-            return f"{label} )"
+            scores = torch.softmax(logits, dim=1)[0]
+            generated_score = scores[1].item()
+            human_written_score = scores[0].item()
+            label = "Generated Text" if generated_score > 0.5 else "Human-Written"
+            return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"
+
         elif model_choice == 'BERT':
             # Tokenize input
             inputs = self.bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -145,9 +151,12 @@ class TextDetectionApp:
 
             # Get classification results
             logits = outputs.logits
-            predicted_class_id = logits.argmax().item()
-            label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
-            return f"{label} )"
+            scores = torch.softmax(logits, dim=1)[0]
+            generated_score = scores[1].item()
+            human_written_score = scores[0].item()
+            label = "Generated Text" if generated_score > 0.5 else "Human-Written"
+            return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"
+
         elif model_choice == 'DistilBERT':
             # Tokenize input
             inputs = self.distilbert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -157,14 +166,19 @@ class TextDetectionApp:
 
             # Get classification results
             logits = outputs.logits
-            predicted_class_id = logits.argmax().item()
-            label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
-            return f"{label} )"
+            scores = torch.softmax(logits, dim=1)[0]
+            generated_score = scores[1].item()
+            human_written_score = scores[0].item()
+            label = "Generated Text" if generated_score > 0.5 else "Human-Written"
+            return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"
 
-        elif model_choice == 'Feedforward':
+        elif model_choice == 'DAIGT-Model':
             # Run feedforward detection
-            detection_result = self.detect_text(text)
-            return f"{detection_result}"
+            detection_score = self.detect_text(text)
+            label = "Generated Text" if detection_score > 0.5 else "Human-Written"
+            generated_score = detection_score
+            human_written_score = 1 - detection_score
+            return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"
 
         else:
             return "Invalid model selection."
@@ -182,7 +196,7 @@ iface = gr.Interface(
     ],
     outputs="text",
     title="Text Classification with Multiple Models",
-    description="Classify text as generated or human-written using DeBERTa, RoBERTa, BERT, DistilBERT, or a custom Feedforward model."
+    description="Classify text as generated or human-written using DeBERTa, RoBERTa, BERT, DistilBERT, or a custom Feedforward model. See the confidence percentages for each prediction."
 )
 
  iface.launch()
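
Not part of this commit: the four transformer branches now repeat the same softmax-and-format block, so the duplication could be factored into a small helper. A minimal sketch, assuming the label ordering used in the diff (index 0 = Human-Written, index 1 = Generated Text); the name _format_prediction is hypothetical.

import torch

def _format_prediction(logits: torch.Tensor) -> str:
    # Hypothetical helper (not in the commit): convert 2-class logits into the
    # "label (xx.xx% Generated, yy.yy% Human-Written)" string built in each branch.
    # Assumes class index 0 = Human-Written and index 1 = Generated Text.
    scores = torch.softmax(logits, dim=1)[0]
    generated_score = scores[1].item()
    human_written_score = scores[0].item()
    label = "Generated Text" if generated_score > 0.5 else "Human-Written"
    return f"{label} ({generated_score * 100:.2f}% Generated, {human_written_score * 100:.2f}% Human-Written)"

Each transformer branch could then end with return _format_prediction(outputs.logits), and the DAIGT-Model branch could build the same string from detection_score and 1 - detection_score.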
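
For reference, a hedged usage sketch of the updated classify_text. It assumes TextDetectionApp() can be constructed without arguments, which this diff does not show, and that the Gradio choices now pass "DAIGT-Model" where they previously passed "Feedforward".

# Hypothetical usage (constructor arguments assumed): compare a transformer
# branch with the renamed DAIGT-Model branch on the same input.
app = TextDetectionApp()
sample = "This is a short sample paragraph to classify."
print(app.classify_text(sample, "DeBERTa"))      # label plus confidence percentages from softmax scores
print(app.classify_text(sample, "DAIGT-Model"))  # label plus percentages derived from detect_text()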