|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
class TextDetectionApp:
    """Backend for a Gradio UI that labels text as generated or human-written.

    Four fine-tuned Hugging Face sequence classifiers (DeBERTa, RoBERTa, BERT,
    DistilBERT) can be queried individually. A TorchScript "Feedforward" model
    combines the RoBERTa and DeBERTa class probabilities into one score.

    Label convention used throughout: class id 1 / ``LABEL_1`` is reported as
    "Generated Text", class id 0 / ``LABEL_0`` as "Human-Written".
    """

    def __init__(self, ff_model_path="model_scripted.pt"):
        """Load every tokenizer/model pair and the TorchScript combiner model.

        Args:
            ff_model_path (str): Path of the TorchScript feedforward model.
                Defaults to ``"model_scripted.pt"`` (the previously hard-coded path).
        """
        self.deberta_tokenizer, self.deberta_model = self._load_hf("zeyadusf/deberta-DAIGT-MODELS")
        self.roberta_tokenizer, self.roberta_model = self._load_hf("zeyadusf/roberta-DAIGT-kaggle")
        self.bert_tokenizer, self.bert_model = self._load_hf("zeyadusf/bert-DAIGT-MODELS")
        self.distilbert_tokenizer, self.distilbert_model = self._load_hf("zeyadusf/distilbert-DAIGT-MODELS")

        # The feature tensor built in generate_ff_input lives on the CPU, so the
        # combiner must as well, even if it was scripted/saved on a GPU.
        self.ff_model = torch.jit.load(ff_model_path, map_location="cpu")
        # A scripted module keeps the train/eval flag it was saved with; force
        # inference mode so train-only behaviour (e.g. dropout, if present) is off.
        self.ff_model.eval()

    @staticmethod
    def _load_hf(repo_id):
        """Return the ``(tokenizer, sequence-classifier)`` pair for a Hub repo id."""
        return (
            AutoTokenizer.from_pretrained(repo_id),
            AutoModelForSequenceClassification.from_pretrained(repo_id),
        )

    @staticmethod
    def _class_probabilities(tokenizer, model, text):
        """Return the softmax class probabilities of ``model`` for a single text.

        Runs under ``torch.no_grad()``: this is pure inference, so no autograd
        graph should be recorded.

        Returns:
            torch.Tensor: 1-D tensor with one probability per class.
        """
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        # A single string is tokenized as a batch of one; keep that row only.
        return torch.softmax(logits, dim=-1)[0]

    @staticmethod
    def _as_label_scores(probabilities):
        """Convert a probability vector into ``[{"label": "LABEL_i", "score": p_i}, ...]``."""
        return [
            {"label": f"LABEL_{i}", "score": p.item()}
            for i, p in enumerate(probabilities)
        ]

    def api_huggingface(self, text):
        """
        Generate predictions using the DeBERTa and RoBERTa models.

        Args:
            text (str): The input text to classify.

        Returns:
            tuple: ``(roberta_predictions, deberta_predictions)``; each is a
            list of ``{"label": "LABEL_i", "score": float}`` dicts ordered by
            class id.
        """
        roberta_predictions = self._as_label_scores(
            self._class_probabilities(self.roberta_tokenizer, self.roberta_model, text)
        )
        deberta_predictions = self._as_label_scores(
            self._class_probabilities(self.deberta_tokenizer, self.deberta_model, text)
        )
        return roberta_predictions, deberta_predictions

    def generate_ff_input(self, models_results):
        """
        Build the Feedforward model's input features from the model predictions.

        Parameters:
            models_results (tuple): ``(roberta_predictions, deberta_predictions)``
                as returned by ``api_huggingface``. Each is a list of
                ``{"label": ..., "score": ...}`` dicts containing ``LABEL_0``
                and ``LABEL_1`` in any order.

        Returns:
            torch.Tensor: float32 tensor of shape ``(1, 4)`` laid out as
            ``[roberta LABEL_0, roberta LABEL_1, deberta LABEL_0, deberta LABEL_1]``.

        Raises:
            ValueError: if a prediction list lacks ``LABEL_0``/``LABEL_1`` or
                is otherwise malformed. Previously this was printed and
                swallowed, which produced a short tensor that failed later.
        """
        roberta, deberta = models_results
        features = []
        for predictions in (roberta, deberta):
            try:
                by_label = {p["label"]: p["score"] for p in predictions}
                features.extend((by_label["LABEL_0"], by_label["LABEL_1"]))
            except (KeyError, TypeError) as err:
                raise ValueError(f"Malformed model predictions: {predictions!r}") from err
        return torch.tensor(features, dtype=torch.float32).view(1, -1)

    def detect_text(self, text):
        """
        Detects whether the input text is generated or human-written using the Feedforward model.

        Args:
            text (str): The input text to classify.

        Returns:
            str: "Generated Text" if the combiner's score exceeds 0.5, else
            "Human-Written".
        """
        features = self.generate_ff_input(self.api_huggingface(text))
        with torch.no_grad():
            # NOTE(review): output[0][0] is compared to 0.5, so it is presumably a
            # sigmoid probability for the "generated" class. Confirm with the model.
            detection_score = self.ff_model(features)[0][0].item()
        return "Generated Text" if detection_score > 0.5 else "Human-Written"

    def classify_text(self, text, model_choice):
        """
        Classifies the input text using the selected model.

        Args:
            text (str): The input text to classify.
            model_choice (str): The model to use ('DeBERTa', 'RoBERTa', 'BERT',
                'DistilBERT', or 'Feedforward').

        Returns:
            str: "Generated Text", "Human-Written", or "Invalid model selection."
            for an unknown (or missing) choice.
        """
        if model_choice == "Feedforward":
            return self.detect_text(text)

        backbones = {
            "DeBERTa": (self.deberta_tokenizer, self.deberta_model),
            "RoBERTa": (self.roberta_tokenizer, self.roberta_model),
            "BERT": (self.bert_tokenizer, self.bert_model),
            "DistilBERT": (self.distilbert_tokenizer, self.distilbert_model),
        }
        pair = backbones.get(model_choice)
        if pair is None:
            return "Invalid model selection."

        tokenizer, model = pair
        predicted_class_id = int(self._class_probabilities(tokenizer, model, text).argmax())
        # Return the bare label; the old f"{label} )" leaked a stray " )" into the UI.
        return "Generated Text" if predicted_class_id == 1 else "Human-Written"
|
|
|
|
|
|
|
# Build the detector once at start-up. Constructing it loads every model, so
# that cost is paid before the UI is served rather than on each request.
app = TextDetectionApp()

iface = gr.Interface(
    fn=app.classify_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your text here..."),
        # Preselect a model so submitting without touching the radio does not
        # come back as "Invalid model selection.".
        gr.Radio(
            choices=["DeBERTa", "RoBERTa", "BERT", "DistilBERT", "Feedforward"],
            value="DeBERTa",
            label="Model Choice",
        ),
    ],
    outputs="text",
    title="Text Classification with Multiple Models",
    description="Classify text as generated or human-written using DeBERTa, RoBERTa, BERT, DistilBERT, or a custom Feedforward model.",
)

# Only start the web server when run as a script (e.g. `python app.py`), so the
# module can be imported without blocking on launch().
if __name__ == "__main__":
    iface.launch()
|
|