fserfati13's picture
Rename app_bert.py to app.py
3422250 verified
raw history blame
No virus
1.94 kB
import streamlit as st
import torch
from transformers import RobertaTokenizer
from bert import RobertaClass
from text_preprocessing import preprocess_text
# Load the fine-tuned RoBERTa classifier and its tokenizer.
#
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are wrapped in ``st.cache_resource`` (available in Streamlit >= 1.18)
# to load the weights once per server process instead of on every rerun.
# On older Streamlit versions the fallback decorator is a no-op, which keeps
# the previous (uncached) behaviour.
_cache_resource = getattr(st, 'cache_resource', lambda func: func)


@_cache_resource
def _load_model():
    """Build ``RobertaClass``, load the fine-tuned weights on CPU, return it.

    The model is put into evaluation mode before it is returned. A freshly
    built ``torch.nn.Module`` is in training mode, so without ``eval()`` any
    dropout layers inside ``RobertaClass`` would stay active and ``predict``
    would give different scores for the same text on every call.
    """
    net = RobertaClass()
    # NOTE(review): torch.load unpickles the file. That is acceptable only
    # because 'model_bert_2.bin' ships with the app. Consider passing
    # weights_only=True (torch >= 1.13) if the file could ever be swapped out.
    state_dict = torch.load('model_bert_2.bin',
                            map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()  # inference only: disable dropout / training-mode layers
    return net


@_cache_resource
def _load_tokenizer():
    """Load the pretrained 'roberta-base' tokenizer used by ``predict``."""
    # NOTE(review): RoBERTa's BPE tokenizer is case-sensitive; confirm whether
    # do_lower_case has any effect here before relying on it.
    return RobertaTokenizer.from_pretrained(
        'roberta-base', truncation=True, do_lower_case=True)


model = _load_model()
tokenizer = _load_tokenizer()
# Page layout: a title, a box for the text to check, and a button that
# triggers classification. The widgets appear in the order they are called.
st.title('ChatGPT detector')
text_input = st.text_input(label='Enter text to classify:', value='')
submit_button = st.button(label='Classify')
# Define prediction function
def predict(text):
    """Classify ``text`` and report the model's confidence in that label.

    The text is cleaned with ``preprocess_text``, then tokenized, and then
    scored by ``model``. The model output is passed through a sigmoid to get
    the probability of the positive class (label 1).

    Args:
        text: raw user input.

    Returns:
        A ``(label, confidence)`` tuple.
        ``label`` is 1 when the sigmoid score is >= 0.5 and 0 otherwise.
        ``confidence`` is the probability of the returned label: ``score``
        for label 1 and ``1 - score`` for label 0.
    """
    encoded = tokenizer(preprocess_text(text), return_tensors='pt',
                        padding=True, truncation=True)
    # Only input_ids and attention_mask are passed to the model below; drop
    # token_type_ids in case the tokenizer emitted them.
    encoded.pop('token_type_ids', None)

    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'],
                       attention_mask=encoded['attention_mask'])

    # .item() needs the model output to contain exactly one element.
    # NOTE(review): this assumes RobertaClass emits a single logit per text.
    score = torch.sigmoid(logits).item()
    if score >= 0.5:
        return 1, score
    return 0, 1 - score
# Handle user interaction: classify only when the button was pressed AND
# there is actual text to classify.
if submit_button:
    if not text_input.strip():
        # Without text there is nothing meaningful for predict() to score.
        # Ask for input instead of reporting a misleading verdict.
        st.warning('Please enter some text to classify.')
    else:
        predicted_label, predicted_prob = predict(text_input)
        # Binary classification: index 0 / 1 matches the label from predict().
        labels = ['written by a human', 'generated by ChatGPT']
        predicted_category = labels[predicted_label]
        predicted_prob_percentage = round(predicted_prob * 100, 2)
        st.write(
            f"This text was {predicted_category} ({predicted_prob_percentage} % confident)")