from spacy.lang.en import English
from spacy.lang.en.stop_words import STOP_WORDS
import pandas as pd
import gradio as gr
from transformers import pipeline
from gradio.themes.utils.colors import red, green
import requests
import json
import os
from dotenv import load_dotenv
import time

# Load environment variables
load_dotenv()

# Initialize the NLP pipeline (sentence splitting and tokenization only)
nlp = English()
nlp.add_pipe("sentencizer")
tokenizer = nlp.tokenizer

# Initialize the text classification pipeline
detector = pipeline(task='text-classification', model='SJTU-CL/RoBERTa-large-ArguGPT-sent')

# Groq API configuration
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("Please set your GROQ_API_KEY in the .env file")

GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"
GROQ_MODEL = "llama3-70b-8192"
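# Note: Groq's hosted model lineup rotates over time; if "llama3-70b-8192" is
# ever decommissioned, swap in any chat model currently listed by the API.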

# Define color map for highlighted text
color_map = {
    '0%': green.c400,
    '10%': green.c300,
    '20%': green.c200,
    '30%': green.c100,
    '40%': green.c50,
    '50%': red.c50,
    '60%': red.c100,
    '70%': red.c200,
    '80%': red.c300,
    '90%': red.c400,
    '100%': red.c500,
}
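# The percentage buckets double as the legend labels for gr.HighlightedText:
# green shades mark spans scored as likely human, red shades as increasingly
# likely machine-generated.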


def is_stopword(word):
    """Check if a word is a stop word or very short."""
    return word.lower() in STOP_WORDS or len(word) <= 2
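# e.g. is_stopword("the") -> True, is_stopword("ox") -> True (too short),
# is_stopword("ephemeral") -> False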


def get_synonyms(word):
    """Get simple, human-readable synonyms using the Groq API."""
    if is_stopword(word):
        return [word]  # Don't fetch synonyms for stop words

    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }
    prompt = f"""Provide a list of exactly 5 simple synonyms for '{word}'.
Return ONLY a JSON array of synonyms without any additional text.
Example: ["use", "employ", "apply", "make use of", "take"]"""
    data = {
        "model": GROQ_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
        "max_tokens": 100,
        "response_format": {"type": "json_object"}
    }
    try:
        response = requests.post(GROQ_API_URL, headers=headers, json=data, timeout=30)
        response.raise_for_status()
        result = response.json()
        content = json.loads(result['choices'][0]['message']['content'])
        # The model may return a bare array or wrap it in an object
        if isinstance(content, list):
            return content[:5]
        if isinstance(content, dict):
            for key in ('synonyms', 'words', 'alternatives'):
                if key in content and isinstance(content[key], list):
                    return content[key][:5]
        return [word]  # Fallback if parsing fails
    except Exception as e:
        print(f"Error getting synonyms: {e}")
        return [word]  # Fall back to the original word if the API call fails
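# Note: Groq's OpenAI-compatible JSON mode only guarantees syntactically valid
# JSON, not a particular shape, so the branches above defensively handle both
# array- and object-shaped replies.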


def identify_problem_words(text):
    """Use the Groq API to flag uncommon, difficult, or AI-sounding words."""
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }
    prompt = f"""Analyze this text and return ONLY a JSON list of words that are:
1. Uncommon (not in everyday vocabulary)
2. Difficult (complex or technical)
3. Likely AI-generated (overly formal, verbose, or unnatural)
Exclude all stop words (a, an, the, and, but, etc.) and very short words (1-2 letters).
Return format: {{"words": ["word1", "word2", ...]}}
Text: {text}"""
    data = {
        "model": GROQ_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.2,
        "max_tokens": 200,
        "response_format": {"type": "json_object"}
    }
    try:
        response = requests.post(GROQ_API_URL, headers=headers, json=data, timeout=30)
        response.raise_for_status()
        result = response.json()
        content = json.loads(result['choices'][0]['message']['content'])
        if isinstance(content, dict) and 'words' in content:
            # Filter out any stop words that slipped through, and lowercase so
            # the membership checks in predict_word are case-insensitive
            return {w.lower() for w in content['words'] if not is_stopword(w)}
        return set()
    except Exception as e:
        print(f"Error identifying problem words: {e}")
        return set()


def predict_word(word, problem_words):
    """Predict AI probability for a single word if it was flagged as a problem word."""
    if len(word) <= 3 or word.lower() not in problem_words or is_stopword(word):
        return 0.0
    try:
        return predict_one_sent(word)
    except Exception:
        return 0.0


def predict_doc(doc):
    start_time = time.time()

    # First identify problem words using Groq
    problem_words = identify_problem_words(doc)
    print(f"Identified problem words: {problem_words}")

    sents = [s.text for s in nlp(doc).sents]
    data = {'sentence': [], 'label': [], 'score': []}
    sent_res = []
    word_highlights = []

    for sent in sents:
        sent_prob = predict_one_sent(sent)

        # Word-level analysis, restricted to the flagged problem words
        tokens = [token.text for token in tokenizer(sent)]
        word_probs = [predict_word(token, problem_words) for token in tokens]
        for word, prob in zip(tokens, word_probs):
            if prob >= 0.2:  # Only highlight words with >=20% AI probability
                if prob < 0.3: label = '20%'
                elif prob < 0.4: label = '30%'
                elif prob < 0.5: label = '40%'
                elif prob < 0.6: label = '50%'
                elif prob < 0.7: label = '60%'
                elif prob < 0.8: label = '70%'
                elif prob < 0.9: label = '80%'
                elif prob < 1: label = '90%'
                else: label = '100%'
                word_highlights.append((word, label))
            else:
                word_highlights.append((word, None))

        data['sentence'].append(sent)
        data['score'].append(round(sent_prob, 4))
        data['label'].append('Human' if sent_prob <= 0.5 else 'Machine')

        if sent_prob < 0.1: label = '0%'
        elif sent_prob < 0.2: label = '10%'
        elif sent_prob < 0.3: label = '20%'
        elif sent_prob < 0.4: label = '30%'
        elif sent_prob < 0.5: label = '40%'
        elif sent_prob < 0.6: label = '50%'
        elif sent_prob < 0.7: label = '60%'
        elif sent_prob < 0.8: label = '70%'
        elif sent_prob < 0.9: label = '80%'
        elif sent_prob < 1: label = '90%'
        else: label = '100%'
        sent_res.append((sent, label))

    df = pd.DataFrame(data)
    csv_path = 'result.csv'
    df.to_csv(csv_path, index=False)
    print(f"Analysis took {time.time() - start_time:.2f} seconds")

    overall_score = df.score.mean()
    overall_label = 'Human' if overall_score <= 0.5 else 'Machine'
    sum_str = (f'The essay was probably written by a {overall_label.lower()}. '
               f'The probability of it being AI-generated is {overall_score:.2f}.')
    return sum_str, sent_res, df, csv_path, word_highlights
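# Note: the tuple returned above must stay in sync with the outputs= lists
# wired to predict_doc in the Blocks layout below.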


def predict_one_sent(sent):
    """Return the detector's probability that `sent` is machine-generated."""
    res = detector(sent)[0]
    org_label, prob = res['label'], res['score']
    # The ArguGPT checkpoint emits LABEL_0 for human text, so flip the score
    # to always report P(machine)
    if org_label == 'LABEL_0':
        prob = 1 - prob
    return prob


def update_text(text, selected_word, replacement, word_highlights):
    new_text = text.replace(selected_word, replacement, 1)
    # Update word_highlights with the new word (assumed to now read as human-written)
    updated_highlights = []
    replaced = False
    for word, label in word_highlights:
        if word == selected_word and not replaced:
            updated_highlights.append((replacement, '0%'))
            replaced = True
        else:
            updated_highlights.append((word, label))
    return new_text, updated_highlights
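# Caveat: str.replace(..., 1) swaps the first occurrence in the raw text, which
# may not be the exact instance the user clicked if the word appears more than
# once.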


def process_word_highlights(highlights):
    """Pass-through so the word-highlight state can be rendered as HighlightedText."""
    return highlights


# Custom CSS for a modern look
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.gradio-header {
    background-color: #4CAF50;
    color: white;
    padding: 10px;
    text-align: center;
    border-radius: 8px;
    margin-bottom: 20px;
}
.gradio-button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 10px 20px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 5px;
    transition: background-color 0.3s;
}
.gradio-button:hover {
    background-color: #45a049;
}
.highlighted-word {
    cursor: pointer;
    padding: 2px 4px;
    border-radius: 3px;
    transition: all 0.2s;
}
.highlighted-word:hover {
    text-decoration: underline;
    background-color: #f0f0f0;
    transform: scale(1.05);
}
.replacement-row {
    border: 1px solid #ddd;
    padding: 15px;
    border-radius: 8px;
    margin-top: 10px;
    background-color: #f9f9f9;
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""## AI vs Human Essay Detector""")
    gr.Markdown("""Identify and replace uncommon, difficult, and AI-generated words in your text.""")

    word_highlights = gr.State([])
    selected_word = gr.State("")

    with gr.Row():
        with gr.Column():
            text_in = gr.Textbox(
                lines=10,
                label='Essay Input',
                placeholder="Paste your essay here...",
                elem_classes=["text-input"]
            )
            btn = gr.Button('Analyze Text', variant="primary")
        with gr.Column():
            sent_res = gr.HighlightedText(
                label='Sentence-level Analysis',
                color_map=color_map,
                show_legend=True
            )
            word_res = gr.HighlightedText(
                label='Word-level Analysis (Click words to replace)',
                color_map=color_map,
                show_legend=True
            )

    with gr.Row():
        summary = gr.Textbox(label='Overall Analysis', interactive=False)
        csv_f = gr.File(label='Download Detailed Analysis')

    with gr.Row():
        tab = gr.Dataframe(
            label='Detailed Sentence Analysis',
            wrap=True
        )

    with gr.Column(visible=False) as replacement_row:
        gr.Markdown("### Replace Word")
        with gr.Row():
            replacement_dropdown = gr.Dropdown(
                label="Select replacement",
                interactive=True,
                allow_custom_value=True
            )
        with gr.Row():
            replace_btn = gr.Button("Replace", variant="primary")
            cancel_btn = gr.Button("Cancel")

    def on_word_select(evt: gr.SelectData):
        if evt.value:
            synonyms = get_synonyms(evt.value)
            return (
                evt.value,
                gr.Dropdown(choices=synonyms, value=evt.value),
                gr.Column(visible=True)
            )
        return None, None, gr.Column(visible=False)

    word_res.select(
        fn=on_word_select,
        outputs=[selected_word, replacement_dropdown, replacement_row]
    )

    replace_btn.click(
        fn=update_text,
        inputs=[text_in, selected_word, replacement_dropdown, word_highlights],
        outputs=[text_in, word_highlights]
    ).then(
        fn=lambda: gr.Column(visible=False),
        outputs=replacement_row
    ).then(
        fn=predict_doc,
        inputs=text_in,
        outputs=[summary, sent_res, tab, csv_f, word_highlights]
    ).then(
        fn=process_word_highlights,
        inputs=word_highlights,
        outputs=word_res
    )

    cancel_btn.click(
        fn=lambda: gr.Column(visible=False),
        outputs=replacement_row
    )

    btn.click(
        fn=predict_doc,
        inputs=text_in,
        outputs=[summary, sent_res, tab, csv_f, word_highlights]
    ).then(
        fn=process_word_highlights,
        inputs=word_highlights,
        outputs=word_res
    )

if __name__ == "__main__":
    demo.launch()