# Spaces:
# Running
# Running
import gradio as gr
from transformers import DebertaTokenizer, DebertaForSequenceClassification, DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import pipeline
import json
import numpy as np
import random


def _load_classifier(save_path):
    '''
    Load a fine-tuned DistilBERT checkpoint and wrap it in a text-classification pipeline.
    Args:
        save_path: str: directory holding the saved model and tokenizer
    Returns:
        tuple: (model, tokenizer, classifier pipeline) loaded from save_path
    '''
    model = DistilBertForSequenceClassification.from_pretrained(save_path)
    tokenizer = DistilBertTokenizer.from_pretrained(save_path)
    classifier = pipeline('text-classification', model=model, tokenizer=tokenizer)
    return model, tokenizer, classifier


save_path_abstract = './fine-tuned-distillberta'
model_abstract, tokenizer_abstract, classifier_abstract = _load_classifier(save_path_abstract)

# NOTE(review): this is the same directory as the abstract model, so the "essay"
# classifier is currently identical to the abstract one — confirm the intended checkpoint.
save_path_essay = './fine-tuned-distillberta'
if save_path_essay == save_path_abstract:
    # Reuse the already-loaded objects instead of keeping a second copy of the same
    # weights in memory; inference through the pipeline does not modify them.
    model_essay, tokenizer_essay, classifier_essay = model_abstract, tokenizer_abstract, classifier_abstract
else:
    model_essay, tokenizer_essay, classifier_essay = _load_classifier(save_path_essay)

# Sample texts for the 'Challenge Yourself!' tab; generate_text_challenge_tab reads
# element [0] of an entry as the text. The context manager closes the file promptly.
with open('samples.json', encoding='utf-8') as samples_file:
    demo_essays = json.load(samples_file)

# Index into demo_essays of the sample currently shown in the challenge tab
# (set by generate_text_challenge_tab, read by correct_label_challenge_tab).
index = None
################# HELPER FUNCTIONS (DETECTION TAB) ####################
def process_result_detection_tab(text):
    '''
    Classify the text into one of the four categories by averaging the soft predictions of the two models.
    Args:
        text: str: the text to be classified
    Returns:
        dict: maps each display label to the mean of the two models' probabilities for it:
            'Machine Generated': float: the probability that the text is machine generated
            'Human Written': float: the probability that the text is human written
            'Machine Written, Machine Humanized': float: the probability that the text is machine written and machine humanized
            'Human Written, Machine Polished': float: the probability that the text is human written and machine polished
        A label reported by only one model is averaged with 0 for the other model; a raw
        label that is not in `mapping` is passed through unchanged.
    '''
    mapping = {'llm': 'Machine Generated', 'human': 'Human Written', 'machine-humanized': 'Machine Written, Machine Humanized', 'machine-polished': 'Human Written, Machine Polished'}
    # top_k=None makes each pipeline return a score for EVERY label. Without it the
    # pipeline returns only its top-1 label; since the two models' top labels need not
    # agree, averaging two top-1 lists position by position can combine scores that
    # belong to different labels. Scores are therefore matched by label name below.
    # truncation=True lets the tokenizer cut over-long inputs (the UI accepts up to
    # 500 words) instead of the model failing on a sequence longer than it supports.
    scores_abstract = {x['label']: x['score'] for x in classifier_abstract(text, top_k=None, truncation=True)}
    scores_essay = {x['label']: x['score'] for x in classifier_essay(text, top_k=None, truncation=True)}
    # Ordered union of the labels seen from either model (dicts keep first-seen order).
    all_labels = dict.fromkeys([*scores_abstract, *scores_essay])
    final_results = {
        mapping.get(label, label): float(0.5 * scores_abstract.get(label, 0.0) + 0.5 * scores_essay.get(label, 0.0))
        for label in all_labels
    }
    print(final_results)
    return final_results
def update_detection_tab(name, uploaded_file, radio_input):
    '''
    Handler for the 'Check Origin' button: classify the pasted text or report on the uploaded file.
    Args:
        name: str: the text entered in the Textbox
        uploaded_file: file: the file from the File input, or None when nothing is uploaded
        radio_input: str: the selected 'Text Type' value (not used by the current logic)
    Returns:
        "Work in progress" when a file is uploaded (file analysis is not implemented yet),
        "" when there is neither text nor a file,
        otherwise the label -> score dict produced by process_result_detection_tab.
    '''
    if uploaded_file is not None:
        return "Work in progress"
    if name == '':
        return ""
    return process_result_detection_tab(name)
def active_button_detection_tab(input_text, file_input):
    '''
    Enable or disable the 'Check Origin' button based on the current inputs.
    An uploaded file always enables it. Without a file, the pasted text must be
    non-empty and contain between 50 and 500 words (inclusive).
    Args:
        input_text: str: the text in the Textbox
        file_input: file: the uploaded file, or None
    Returns:
        gr.Button: the 'Check Origin' button with its interactivity set accordingly.
    '''
    if file_input is not None:
        enabled = True
    else:
        enabled = input_text != "" and 50 <= len(input_text.split()) <= 500
    return gr.Button("Check Origin", variant="primary", interactive=enabled)
def clear_detection_tab():
    '''
    Reset the 'Try it!' tab inputs.
    The 'Check Origin' button is disabled so it cannot be clicked while the inputs are empty.
    Returns:
        str: "" to empty the Textbox.
        None: to clear the File input.
        gr.Button: the 'Check Origin' button, non-interactive.
    '''
    disabled_check = gr.Button("Check Origin", variant="primary", interactive=False)
    cleared_text, cleared_file = "", None
    return cleared_text, cleared_file, disabled_check
def count_words_detection_tab(text):
    '''
    Build the word-counter caption shown under the Textbox whenever the text changes.
    Words are whitespace-separated tokens.
    Args:
        text: str: the current Textbox contents
    Returns:
        str: the caption for the Markdown widget, e.g. "12/500 words (Minimum 50 words)"
    '''
    word_count = len(text.split())
    return f'{word_count}/500 words (Minimum 50 words)'
################# HELPER FUNCTIONS (CHALLENGE TAB) ####################
def clear_challenge_tab():
    '''
    Reset the four answer buttons of the 'Challenge Yourself' tab.
    All buttons are made non-interactive so nothing can be answered until a new sample is generated.
    Returns:
        gr.Button x4: 'Machine-Generated', 'Human-Written', 'Machine-Humanized' and
            'Machine-Polished', in that order, each secondary and non-interactive.
        str: '' (an empty value for the last wired output component).
    '''
    answer_labels = ("Machine-Generated", "Human-Written", "Machine-Humanized", "Machine-Polished")
    disabled = [gr.Button(label, variant="secondary", interactive=False) for label in answer_labels]
    return (*disabled, '')
def generate_text_challenge_tab():
    '''
    Pick a random sample from the first 80 entries of demo_essays and enable the answer buttons.
    The chosen position is stored in the module-level `index` so the correct answer can be looked up later.
    Returns:
        str: the sample text (element [0] of the chosen entry).
        gr.Button x4: 'Machine-Generated', 'Human-Written', 'Machine-Humanized' and
            'Machine-Polished', in that order, each secondary and interactive.
        str: '' to clear the Result.
    '''
    global index  # read back by correct_label_challenge_tab
    answer_labels = ("Machine-Generated", "Human-Written", "Machine-Humanized", "Machine-Polished")
    enabled = [gr.Button(label, variant="secondary", interactive=True) for label in answer_labels]
    index = random.choice(range(80))
    sample = demo_essays[index]
    return (sample[0], *enabled, '')
def correct_label_challenge_tab():
    '''
    Return the true label of the current challenge sample, derived from the module-level `index`.
    The samples are laid out in four consecutive groups of 20:
    0-19 human, 20-39 machine-generated, 40-59 machine-polished, 60-79 machine-humanized.
    Returns:
        str: the correct label, or None when `index` is outside 0-79.
    '''
    groups = ('Human-Written', 'Machine-Generated', 'Machine-Polished', 'Machine-Humanized')
    if not 0 <= index < 80:
        return None
    return groups[int(index // 20)]
def show_result_challenge_tab(button):
    '''
    Grade the user's answer and recolor the answer buttons.
    The correct answer's button gets the 'primary' variant. A wrong pick gets 'stop'.
    Every other button is 'secondary'.
    Args:
        button: str: the label of the button the user clicked
    Returns:
        str: 'Correct' or 'Incorrect'.
        gr.Button x4: 'Machine-Generated', 'Human-Written', 'Machine-Humanized' and
            'Machine-Polished', in that order, with their new variants.
    '''
    correct_btn = correct_label_challenge_tab()

    def variant_for(label):
        # The correct answer wins over the user's pick when they coincide.
        if label == correct_btn:
            return "primary"
        if label == button:
            return "stop"
        return "secondary"

    outcome = 'Correct' if button == correct_btn else 'Incorrect'
    mg, hw, mh, mp = (
        gr.Button(label, variant=variant_for(label))
        for label in ("Machine-Generated", "Human-Written", "Machine-Humanized", "Machine-Polished")
    )
    return outcome, mg, hw, mh, mp
############################## GRADIO UI ##############################
with gr.Blocks() as demo:
    # Opening and closing tags now match (the opening tag was misspelled "<centre>").
    gr.Markdown("""<h1><center>Machine Generated Text (MGT) Detection</center></h1>""")
    with gr.Tab('Try it!'):
        with gr.Row():
            # NOTE(review): this value is passed to update_detection_tab, which does not
            # currently use it; both models are always averaged.
            radio_button = gr.Dropdown(['Student Essay', 'Scientific Abstract'], label = 'Text Type', info = 'We have specialized models that work on domain-specific text.', value='Student Essay')
        with gr.Row():
            input_text = gr.Textbox(placeholder="Paste your text here...", label="Text", lines=10, max_lines=15)
            file_input = gr.File(label="Upload File", file_types=[".txt", ".pdf"])
        with gr.Row():
            wc = gr.Markdown("0/500 words (Minimum 50 words)")
        with gr.Row():
            check_button = gr.Button("Check Origin", variant="primary", interactive=False)
            clear_button = gr.ClearButton([input_text, file_input], variant="stop")
        out = gr.Label(label='Result')
        clear_button.add(out)
        check_button.click(fn=update_detection_tab, inputs=[input_text, file_input, radio_button], outputs=out)
        input_text.change(count_words_detection_tab, input_text, wc, show_progress=False)
        input_text.input(
            active_button_detection_tab,
            [input_text, file_input],
            [check_button],
        )
        file_input.upload(
            active_button_detection_tab,
            [input_text, file_input],
            [check_button],
        )
        # Re-evaluate the button when the uploaded file is removed from the File widget.
        # Otherwise it stays enabled even though the file it was enabled for is gone.
        file_input.clear(
            active_button_detection_tab,
            [input_text, file_input],
            [check_button],
        )
        clear_button.click(
            clear_detection_tab,
            inputs=[],
            outputs=[input_text, file_input, check_button],
        )
        # Adding JavaScript to simulate file input click
        # NOTE(review): Gradio sanitizes HTML rendered through Markdown, so this <script>
        # most likely never executes — confirm in the browser; Blocks(js=...)/head= is the
        # supported way to inject JavaScript.
        gr.Markdown(
            """
            <script>
            document.addEventListener("DOMContentLoaded", function() {
            const uploadButton = Array.from(document.getElementsByTagName('button')).find(el => el.innerText === "Upload File");
            if (uploadButton) {
            uploadButton.onclick = function() {
            document.querySelector('input[type="file"]').click();
            };
            }
            });
            </script>
            """
        )
    with gr.Tab('Challenge Yourself!'):
        gr.Markdown(
            """
            <style>
            .gr-button-secondary {
            width: 100px;
            height: 30px;
            padding: 5px;
            }
            .gr-row {
            display: flex;
            align-items: center;
            gap: 10px;
            }
            .gr-block {
            padding: 20px;
            }
            .gr-markdown p {
            font-size: 16px;
            }
            </style>
            <span style='font-family: Arial, sans-serif; font-size: 20px;'>Was this text written by <strong>human</strong> or <strong>AI</strong>?</span>
            <p style='font-family: Arial, sans-serif;'>Try detecting one of our sample texts:</p>
            """
        )
        with gr.Row():
            generate = gr.Button("Generate Sample Text", variant="primary")
            clear = gr.ClearButton([], variant="stop")
        with gr.Row():
            text = gr.Textbox(value="", label="Text", lines=20, interactive=False)
        with gr.Row():
            mg = gr.Button("Machine-Generated", variant="secondary", interactive=False)
            hw = gr.Button("Human-Written", variant="secondary", interactive=False)
            mh = gr.Button("Machine-Humanized", variant="secondary", interactive=False)
            mp = gr.Button("Machine-Polished", variant="secondary", interactive=False)
        with gr.Row():
            result = gr.Label(label="Result", value="")
        clear.add([result, text])
        generate.click(generate_text_challenge_tab, [], [text, mg, hw, mh, mp, result])
        # Each button passes itself as input, so the handler receives that button's label.
        for button in [mg, hw, mh, mp]:
            button.click(show_result_challenge_tab, [button], [result, mg, hw, mh, mp])
        clear.click(clear_challenge_tab, [], [mg, hw, mh, mp, result])
demo.launch(share=False)