zeyadusf commited on
Commit
c7fe332
1 Parent(s): 6b44742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -10
app.py CHANGED
@@ -12,6 +12,14 @@ class TextDetectionApp:
12
  self.roberta_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
13
  self.roberta_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
14
 
 
 
 
 
 
 
 
 
15
  # Load Feedforward model
16
  self.ff_model = torch.jit.load("model_scripted.pt")
17
 
@@ -91,7 +99,7 @@ class TextDetectionApp:
91
  with torch.no_grad():
92
  detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
93
  # Return result based on the score threshold
94
- return "Generated" if detection_score > 0.5 else "Human-Written"
95
 
96
  def classify_text(self, text, model_choice):
97
  """
@@ -99,7 +107,7 @@ class TextDetectionApp:
99
 
100
  Args:
101
  text (str): The input text to classify.
102
- model_choice (str): The model to use ('DeBERTa', 'RoBERTa', or 'Feedforward').
103
 
104
  Returns:
105
  str: The classification result.
@@ -114,9 +122,8 @@ class TextDetectionApp:
114
  # Get classification results
115
  logits = outputs.logits
116
  predicted_class_id = logits.argmax().item()
117
- label = "Generated" if predicted_class_id == 1 else "Human-Written"
118
- return f"DeBERTa Prediction: {label} (Class {predicted_class_id})"
119
-
120
  elif model_choice == 'RoBERTa':
121
  # Tokenize input
122
  inputs = self.roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -127,13 +134,37 @@ class TextDetectionApp:
127
  # Get classification results
128
  logits = outputs.logits
129
  predicted_class_id = logits.argmax().item()
130
- label = "Generated" if predicted_class_id == 1 else "Human-Written"
131
- return f"RoBERTa Prediction: {label} (Class {predicted_class_id})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  elif model_choice == 'Feedforward':
134
  # Run feedforward detection
135
  detection_result = self.detect_text(text)
136
- return f"Feedforward Detection: {detection_result}"
137
 
138
  else:
139
  return "Invalid model selection."
@@ -147,11 +178,11 @@ iface = gr.Interface(
147
  fn=app.classify_text,
148
  inputs=[
149
  gr.Textbox(lines=2, placeholder="Enter your text here..."),
150
- gr.Radio(choices=["DeBERTa", "RoBERTa", "Feedforward"], label="Model Choice")
151
  ],
152
  outputs="text",
153
  title="Text Classification with Multiple Models",
154
- description="Classify text as generated or human-written using DeBERTa, RoBERTa, or a custom Feedforward model."
155
  )
156
 
157
  iface.launch()
 
12
  self.roberta_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
13
  self.roberta_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/roberta-DAIGT-kaggle")
14
 
15
+ # Load BERT model and tokenizer
16
+ self.bert_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/bert-DAIGT-MODELS")
17
+ self.bert_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/bert-DAIGT-MODELS")
18
+
19
+ # Load DistilBERT model and tokenizer
20
+ self.distilbert_tokenizer = AutoTokenizer.from_pretrained("zeyadusf/distilbert-DAIGT-MODELS")
21
+ self.distilbert_model = AutoModelForSequenceClassification.from_pretrained("zeyadusf/distilbert-DAIGT-MODELS")
22
+
23
  # Load Feedforward model
24
  self.ff_model = torch.jit.load("model_scripted.pt")
25
 
 
99
  with torch.no_grad():
100
  detection_score = self.ff_model(self.generate_ff_input(self.api_huggingface(text)))[0][0].item()
101
  # Return result based on the score threshold
102
+ return "Generated Text" if detection_score > 0.5 else "Human-Written"
103
 
104
  def classify_text(self, text, model_choice):
105
  """
 
107
 
108
  Args:
109
  text (str): The input text to classify.
110
+ model_choice (str): The model to use ('DeBERTa', 'RoBERTa', 'BERT', 'DistilBERT', or 'Feedforward').
111
 
112
  Returns:
113
  str: The classification result.
 
122
  # Get classification results
123
  logits = outputs.logits
124
  predicted_class_id = logits.argmax().item()
125
+ label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
126
+ return f"{label} )"
 
127
  elif model_choice == 'RoBERTa':
128
  # Tokenize input
129
  inputs = self.roberta_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
134
  # Get classification results
135
  logits = outputs.logits
136
  predicted_class_id = logits.argmax().item()
137
+ label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
138
+ return f"{label} )"
139
+ elif model_choice == 'BERT':
140
+ # Tokenize input
141
+ inputs = self.bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
142
+
143
+ # Run model
144
+ outputs = self.bert_model(**inputs)
145
+
146
+ # Get classification results
147
+ logits = outputs.logits
148
+ predicted_class_id = logits.argmax().item()
149
+ label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
150
+ return f"{label} )"
151
+ elif model_choice == 'DistilBERT':
152
+ # Tokenize input
153
+ inputs = self.distilbert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
154
+
155
+ # Run model
156
+ outputs = self.distilbert_model(**inputs)
157
+
158
+ # Get classification results
159
+ logits = outputs.logits
160
+ predicted_class_id = logits.argmax().item()
161
+ label = "Generated Text" if predicted_class_id == 1 else "Human-Written"
162
+ return f"{label} )"
163
 
164
  elif model_choice == 'Feedforward':
165
  # Run feedforward detection
166
  detection_result = self.detect_text(text)
167
+ return f"{detection_result}"
168
 
169
  else:
170
  return "Invalid model selection."
 
178
  fn=app.classify_text,
179
  inputs=[
180
  gr.Textbox(lines=2, placeholder="Enter your text here..."),
181
+ gr.Radio(choices=["DeBERTa", "RoBERTa", "BERT", "DistilBERT", "Feedforward"], label="Model Choice")
182
  ],
183
  outputs="text",
184
  title="Text Classification with Multiple Models",
185
+ description="Classify text as generated or human-written using DeBERTa, RoBERTa, BERT, DistilBERT, or a custom Feedforward model."
186
  )
187
 
188
  iface.launch()