import gradio as gr
import torch
from torch.nn.functional import softmax
import shap
import requests
from bs4 import BeautifulSoup
from sklearn.metrics.pairwise import cosine_similarity
from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaModel, pipeline
from IPython.display import HTML
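# Fine-tuned RoBERTa classifier (local checkpoint) used to distinguish Chat-GPT from human text.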
model_dir = 'temp'
tokenizer = RobertaTokenizer.from_pretrained(model_dir)
model = RobertaForSequenceClassification.from_pretrained(model_dir)
#pipe = pipeline("text-classification", model="thugCodeNinja/robertatemp")
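# Base roberta-base encoder (no classification head), used only to embed texts for the plagiarism check.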
tokenizer1 = RobertaTokenizer.from_pretrained('roberta-base')
model1 = RobertaModel.from_pretrained('roberta-base')
pipe = pipeline("text-classification",model=model,tokenizer=tokenizer)
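# Classify the input text, explain the decision with a SHAP text plot, and look up similar web articles.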
def process_text(input_text):
    if not input_text:
        return '', '', '', '', []
    text = input_text
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=1)
    max_prob, predicted_class_id = torch.max(probs, dim=1)
    prob = str(round(max_prob.item() * 100, 2))
    label = model.config.id2label[predicted_class_id.item()]
    final_label = 'Human' if label == 'LABEL_0' else 'Chat-GPT'
    processed_result = text
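    # Plagiarism helpers: query Google Programmable Search, scrape the result pages,
    # and rank them by embedding similarity to the input text.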
    def search(text):
        query = text
        api_key = 'AIzaSyClvkiiJTZrCJ8BLqUY9I38WYmbve8g-c8'
        search_engine_id = '53d064810efa44ce7'
        url = f'https://www.googleapis.com/customsearch/v1?key={api_key}&cx={search_engine_id}&q={query}'
        try:
            response = requests.get(url)
            data = response.json()
            return data
        except Exception as e:
            return {'error': str(e)}
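    # Download a result page and pull its visible paragraph text (best effort).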
    def get_article_text(url):
        try:
            response = requests.get(url)
            if response.status_code == 200:
                soup = BeautifulSoup(response.content, 'html.parser')
                # Extract text from the article content (you may need to adjust this based on the website's structure)
                article_text = ' '.join([p.get_text() for p in soup.find_all('p')])
                return article_text
        except Exception as e:
            print(f"An error occurred: {e}")
        return ''
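    # Rank the search results by cosine similarity between mean-pooled RoBERTa embeddings.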
    def find_plagiarism(text):
        search_results = search(text)
        if 'items' not in search_results:
            return []
        similar_articles = []
        for item in search_results['items']:
            link = item.get('link', '')
            article_text = get_article_text(link)
            if article_text:
                # Tokenize and encode the input text and the article text with the base tokenizer
                encoding1 = tokenizer1(text, max_length=512, truncation=True, padding=True, return_tensors="pt")
                encoding2 = tokenizer1(article_text, max_length=512, truncation=True, padding=True, return_tensors="pt")
                # Calculate embeddings with the base model (mean-pooled last hidden states)
                with torch.no_grad():
                    embedding1 = model1(**encoding1).last_hidden_state.mean(dim=1)
                    embedding2 = model1(**encoding2).last_hidden_state.mean(dim=1)
                # Calculate cosine similarity between the input text and the article text embeddings
                similarity = cosine_similarity(embedding1, embedding2)[0][0]
                similar_articles.append({'Link': link, 'Similarity': similarity})
        similar_articles = sorted(similar_articles, key=lambda x: x['Similarity'], reverse=True)
        threshold = 0.5  # Adjust the threshold as needed
        similar_articles = [article for article in similar_articles if article['Similarity'] > threshold]
        return similar_articles[:5]
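    # Explain the pipeline's prediction with SHAP and render the token-level text plot as HTML.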
    prediction = pipe([text])
    explainer = shap.Explainer(pipe)
    shap_values = explainer([text])
    shap_plot_html = HTML(shap.plots.text(shap_values, display=False)).data
    similar_articles = find_plagiarism(text)
    # Convert the ranked articles into rows matching the Dataframe headers
    article_rows = [[a['Link'], a['Similarity']] for a in similar_articles]
    return processed_result, prob, final_label, shap_plot_html, article_rows
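# Gradio interface: one text box in; processed text, probability, label, SHAP plot, and similar-article table out.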
text_input = gr.Textbox(label="Enter text")
outputs = [gr.Textbox(label="Processed text"), gr.Textbox(label="Probability"), gr.Textbox(label="Label"), gr.HTML(label="SHAP Plot"), gr.Dataframe(label="Similar Articles", headers=["Link", "Similarity"], row_count=5)]
title = "Group 2 - ChatGPT text detection module"
description = '''Please upload text files and text input responsibly and await the explainable results. The approach in place involves fine-tuning a RoBERTa model for text classification. Once the classification is done, the decision is explained through the SHAP text plot.
The predicted probability is explained in particular by the attention plots produced through SHAP.'''
gr.Interface(fn=process_text, title=title, description=description, inputs=[text_input], outputs=outputs).launch()