Spaces:
Runtime error
Runtime error
import nltk | |
nltk.download('punkt') | |
import pandas as pd | |
import gradio as gr | |
from nltk import sent_tokenize | |
from transformers import pipeline | |
from gradio.themes.utils.colors import red, green | |
# Recent NLTK releases (3.8.2+) make sent_tokenize load the 'punkt_tab'
# resource rather than the pickled 'punkt' one, so downloading only 'punkt'
# leaves predict_doc failing with a LookupError on its first call.
nltk.download('punkt_tab')

# Hugging Face text-classification pipeline for the sentence-level ArguGPT
# RoBERTa detector; created once at import time and reused for every request.
detector = pipeline(task='text-classification', model='SJTU-CL/RoBERTa-large-ArguGPT-sent')

# Highlight colours keyed by the ten-percent AI-probability bins that
# predict_doc assigns to each sentence ('0%' is the lowest bin, '100%' is
# reserved for a probability of exactly 1).  Bins below 50% use green shades
# going from green.c400 down to green.c50 as the bin approaches 50%; bins from
# 50% upward use red shades going from red.c50 up to red.c500.
color_map = {
    '0%': green.c400,
    '10%': green.c300,
    '20%': green.c200,
    '30%': green.c100,
    '40%': green.c50,
    '50%': red.c50,
    '60%': red.c100,
    '70%': red.c200,
    '80%': red.c300,
    '90%': red.c400,
    '100%': red.c500,
}
# Exclusive upper edges of the ten-percent bins used to pick a color_map key:
# a probability p lands in bin i (key f'{10 * i}%') for the first edge with
# p < edge; a p that is below no edge (p >= 1, or NaN) maps to '100%'.
_PROB_BIN_EDGES = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1)


def _prob_to_bin_label(prob):
    """Return the color_map key ('0%' ... '100%') for an AI probability."""
    index = next(
        (i for i, edge in enumerate(_PROB_BIN_EDGES) if prob < edge),
        len(_PROB_BIN_EDGES),
    )
    return f'{10 * index}%'


def predict_doc(doc):
    """Classify each sentence of an essay and summarise the essay overall.

    The essay is split with NLTK's ``sent_tokenize`` and every sentence is
    scored by ``predict_one_sent`` (probability of being AI-generated).  A
    sentence is labelled 'Machine' when its probability exceeds 0.5 and
    'Human' otherwise; the essay as a whole is judged by the mean of the
    (4-decimal rounded) sentence scores, using the same 0.5 threshold.

    Args:
        doc: The essay text from the input textbox.  ``None`` is treated as an
            empty string.

    Returns:
        A 4-tuple ``(summary, highlighted, df, csv_path)`` in the order the
        Gradio outputs are wired up:
        ``summary`` is a one-line verdict string; ``highlighted`` is a list of
        ``(sentence, bin_label)`` pairs for ``gr.HighlightedText``; ``df`` is a
        DataFrame with columns ``sentence``, ``label`` and ``score``; and
        ``csv_path`` is ``'result.csv'``, the CSV copy of ``df`` written to the
        current working directory.
    """
    sents = sent_tokenize(doc or '')
    data = {'sentence': [], 'label': [], 'score': []}
    highlighted = []
    for sent in sents:
        prob = predict_one_sent(sent)
        data['sentence'].append(sent)
        data['score'].append(round(prob, 4))
        data['label'].append('Human' if prob <= 0.5 else 'Machine')
        highlighted.append((sent, _prob_to_bin_label(prob)))

    df = pd.DataFrame(data)
    # NOTE(review): every request writes to the same fixed file name, so two
    # concurrent users could download each other's results. Confirm the app's
    # queue/concurrency settings, or switch to a per-request temporary file.
    df.to_csv('result.csv')

    if df.empty:
        # With no sentences, df.score.mean() is NaN and `NaN <= 0.5` is False,
        # which used to report an empty essay as written by "Machine".
        return (
            'No sentences were found. Please enter an essay in the textbox.',
            highlighted,
            df,
            'result.csv',
        )

    overall_score = df.score.mean()
    overall_label = 'Human' if overall_score <= 0.5 else 'Machine'
    sum_str = (
        f'The essay is probably written by {overall_label}. '
        f'The probability of being generated by AI is {overall_score}'
    )
    return sum_str, highlighted, df, 'result.csv'
def predict_one_sent(sent):
    """Return the detector's probability that *sent* is machine-generated.

    The text-classification pipeline reports only the winning label and its
    score. ``LABEL_1`` is treated as the machine class, so its score is used
    as-is. A ``LABEL_0`` score is converted to the complementary probability:

        LABEL_1, 0.66 -> 0.66
        LABEL_0, 0.66 -> 0.34

    Args:
        sent: A single sentence of text.

    Returns:
        The machine-generated probability (presumably in [0, 1], as the
        pipeline's score is a class probability). ``predict_doc`` treats
        values above 0.5 as 'Machine'.
    """
    # truncation=True clips input longer than the model's maximum sequence
    # length (512 tokens for RoBERTa checkpoints; confirm for this one).
    # Without it, a very long "sentence", e.g. an essay with no sentence-ending
    # punctuation, makes the forward pass fail instead of returning a score.
    res = detector(sent, truncation=True)[0]
    org_label, prob = res['label'], res['score']
    if org_label == 'LABEL_0':
        prob = 1 - prob
    return prob
# Gradio UI: the essay textbox and predict button sit beside the per-sentence
# highlighted result, followed by a row with the summary text and the CSV
# download, and then the full per-sentence table.
# NOTE(review): the original indentation was lost in this copy of the file;
# the nesting below is a reconstruction. Confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                lines=5,
                label='Essay input',
                info='Please enter the essay in the textbox'
            )
            btn = gr.Button('Predict who writes this essay!')
        # Sentences coloured by their AI-probability bin via color_map.
        sent_res = gr.HighlightedText(label='Labeled Result', color_map=color_map)
    with gr.Row():
        summary = gr.Text(label='Result summary')
        csv_f = gr.File(label='CSV file storing data with all sentences.')
    tab = gr.Dataframe(label='Table with Probability Score', row_count=100)
    # The outputs list is matched positionally to predict_doc's return tuple:
    # (summary string, highlighted pairs, DataFrame, CSV path).
    btn.click(predict_doc, inputs=[text_in], outputs=[summary, sent_res, tab, csv_f], api_name='predict_doc')
# Unguarded top-level call: the server starts whenever this module is executed
# or imported.
demo.launch()