zeyadusf commited on
Commit
6b44742
1 Parent(s): dbee1f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -86,11 +86,12 @@ class TextDetectionApp:
86
  Detects whether the input text is generated or human-written using the Feedforward model.
87
 
88
  Returns:
89
- float: The detection result.
90
  """
91
  with torch.no_grad():
92
- self.output = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
93
- return self.output
 
94
 
95
  def classify_text(self, text, model_choice):
96
  """
@@ -113,7 +114,8 @@ class TextDetectionApp:
113
  # Get classification results
114
  logits = outputs.logits
115
  predicted_class_id = logits.argmax().item()
116
- return f"DeBERTa Prediction: Class {predicted_class_id}"
 
117
 
118
  elif model_choice == 'RoBERTa':
119
  # Tokenize input
@@ -125,12 +127,13 @@ class TextDetectionApp:
125
  # Get classification results
126
  logits = outputs.logits
127
  predicted_class_id = logits.argmax().item()
128
- return f"RoBERTa Prediction: Class {predicted_class_id}"
 
129
 
130
  elif model_choice == 'Feedforward':
131
  # Run feedforward detection
132
- detection_score = self.detect_text(text)
133
- return f"Feedforward Detection Score: {detection_score}"
134
 
135
  else:
136
  return "Invalid model selection."
@@ -148,7 +151,7 @@ iface = gr.Interface(
148
  ],
149
  outputs="text",
150
  title="Text Classification with Multiple Models",
151
- description="Classify text using DeBERTa, RoBERTa, or a custom Feedforward model."
152
  )
153
 
154
  iface.launch()
 
86
  Detects whether the input text is generated or human-written using the Feedforward model.
87
 
88
  Returns:
89
+ str: The detection result indicating if the text is generated or human-written.
90
  """
91
  with torch.no_grad():
92
+ detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
93
+ # Return result based on the score threshold
94
+ return "Generated" if detection_score > 0.5 else "Human-Written"
95
 
96
  def classify_text(self, text, model_choice):
97
  """
 
114
  # Get classification results
115
  logits = outputs.logits
116
  predicted_class_id = logits.argmax().item()
117
+ label = "Generated" if predicted_class_id == 1 else "Human-Written"
118
+ return f"DeBERTa Prediction: {label} (Class {predicted_class_id})"
119
 
120
  elif model_choice == 'RoBERTa':
121
  # Tokenize input
 
127
  # Get classification results
128
  logits = outputs.logits
129
  predicted_class_id = logits.argmax().item()
130
+ label = "Generated" if predicted_class_id == 1 else "Human-Written"
131
+ return f"RoBERTa Prediction: {label} (Class {predicted_class_id})"
132
 
133
  elif model_choice == 'Feedforward':
134
  # Run feedforward detection
135
+ detection_result = self.detect_text(text)
136
+ return f"Feedforward Detection: {detection_result}"
137
 
138
  else:
139
  return "Invalid model selection."
 
151
  ],
152
  outputs="text",
153
  title="Text Classification with Multiple Models",
154
+ description="Classify text as generated or human-written using DeBERTa, RoBERTa, or a custom Feedforward model."
155
  )
156
 
157
  iface.launch()