Spaces:
Running
Running
import streamlit as st | |
import tensorflow as tf | |
import numpy as np | |
import nltk | |
import os | |
from nltk.tokenize import sent_tokenize | |
from transformers import DistilBertTokenizerFast, TFDistilBertForSequenceClassification | |
# Hugging Face cache directory.
# NOTE(review): transformers is imported before this assignment runs, and it
# may read TRANSFORMERS_CACHE only at import time, so this line might have no
# effect; the from_pretrained() calls pass cache_dir explicitly anyway.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

# NLTK "punkt_tab" data, needed by sent_tokenize().
# Streamlit re-executes this whole script on every user interaction, so both
# steps are guarded to be idempotent: the search path is registered once
# (instead of gaining a duplicate entry per rerun), and nltk.download() is only
# called when the resource cannot already be found on the search path.
nltk_data_path = "/tmp/nltk_data"
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)
try:
    nltk.data.find("tokenizers/punkt_tab")
except LookupError:
    nltk.download("punkt_tab", download_dir=nltk_data_path)
# Cache the model/tokenizer for the lifetime of the server process.
# Streamlit reruns this script on every interaction; without the cache the
# model weights would be reloaded on every button click. show_spinner=False
# avoids rendering a spinner element before st.set_page_config() is called,
# which older Streamlit versions reject.
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer():
    """Load and return ``(tokenizer, model)`` from the Hugging Face Hub.

    The tokenizer is ``distilbert-base-uncased`` and the classifier is
    ``sundaram07/distilbert-sentence-classifier``; both are stored on disk
    under ``/tmp/huggingface``.

    NOTE(review): presumably the classifier was fine-tuned from
    distilbert-base-uncased, so the two share a vocabulary — confirm.
    """
    tokenizer = DistilBertTokenizerFast.from_pretrained(
        "distilbert-base-uncased", cache_dir="/tmp/huggingface"
    )
    model = TFDistilBertForSequenceClassification.from_pretrained(
        "sundaram07/distilbert-sentence-classifier", cache_dir="/tmp/huggingface"
    )
    return tokenizer, model


tokenizer, model = load_model_and_tokenizer()
# Score a single sentence with the classifier.
def predict_sentence_ai_probability(sentence):
    """Return the sigmoid of the classifier's first output logit for *sentence*.

    The sentence is tokenized with the module-level ``tokenizer`` (truncated
    to the model's maximum length), passed through the module-level ``model``,
    and the score is returned as a NumPy scalar in [0, 1].

    NOTE(review): the caller flags *low* scores (<= threshold) as AI-written,
    so despite this function's name the value may really be the "human"
    probability. If the classification head has two labels, softmax would be
    the usual normalization rather than a per-logit sigmoid — confirm against
    the model config before changing either side.
    """
    encoded = tokenizer(sentence, return_tensors="tf", truncation=True, padding=True)
    probabilities = tf.sigmoid(model(encoded).logits)
    # Row 0 = the single input sentence; column 0 = the first output unit.
    return probabilities[0, 0].numpy()
# Aggregate per-sentence scores into a document-level percentage.
def predict_ai_generated_percentage(text, threshold=0.15):
    """Split *text* into sentences, score each one, and report the flagged share.

    Args:
        text: Raw input text; surrounding whitespace is stripped before NLTK's
            ``sent_tokenize`` splits it into sentences.
        threshold: A sentence is flagged as AI when its score from
            ``predict_sentence_ai_probability`` is ``<= threshold``.

    Returns:
        ``(percentage, results)``: ``percentage`` is the share of flagged
        sentences in [0, 100], and ``results`` is a list of
        ``(sentence, score, is_ai)`` tuples in input order. Text that yields
        no sentences returns ``(0.0, [])``.

    NOTE(review): flagging *low* scores as AI looks inverted relative to the
    scorer's name; presumably the score is really P(human) — confirm before
    changing the comparison.
    """
    sentences = sent_tokenize(text.strip())
    if not sentences:
        return 0.0, []

    results = []
    for sentence in sentences:
        score = predict_sentence_ai_probability(sentence)
        results.append((sentence, score, score <= threshold))

    flagged = sum(1 for _sentence, _score, is_ai in results if is_ai)
    return flagged / len(sentences) * 100, results
# Streamlit UI: page configuration and header.
# set_page_config() is called before any other element is rendered, as
# Streamlit requires for this call.
st.set_page_config(page_title="AI Detector", layout="wide")
st.title("AI Content Detector")
st.markdown("This app detects the percentage of **AI-generated content** using sentence-level analysis with DistilBERT.")

# Text input: the text area returns "" until the user types something.
user_input = st.text_area("Paste your text below to check for AI-generated sentences:", height=300)
# Analyze button: run the analysis and store the outcome in st.session_state
# so the results section can render it (and keep it across later reruns).
if st.button("Analyze"):
    # Clear results from any previous run so stale output is not shown when
    # this run fails validation.
    st.session_state.analysis_done = False
    st.session_state.analysis_results = None
    st.session_state.ai_percentage = None

    if not user_input.strip():
        st.warning("⚠️ Please enter some text.")
    else:
        ai_percentage, analysis_results = predict_ai_generated_percentage(user_input)
        if not analysis_results:
            st.warning("⚠️ Not enough valid sentences to analyze.")
        else:
            st.session_state.analysis_done = True
            st.session_state.analysis_results = analysis_results
            st.session_state.ai_percentage = ai_percentage
# Show results of the most recent successful analysis, read back from
# st.session_state (set by the Analyze button handler).
if st.session_state.get("analysis_done", False):
    # Sentences are user-supplied text interpolated into Markdown. Escape
    # Markdown/LaTeX metacharacters so that e.g. "_" or "*" cannot close the
    # italic wrapper early and "$...$" is not rendered as math.
    md_escape = str.maketrans({ch: "\\" + ch for ch in "\\`*_{}[]()#+-.!|<>~$"})

    st.subheader("Sentence-level Analysis")
    for i, (sentence, _prob, is_ai) in enumerate(st.session_state.analysis_results, start=1):
        label = "🔴 AI" if is_ai else "🟢 Human"
        st.markdown(f"**{i}.** _{sentence.translate(md_escape)}_\n\n{label}")

    st.subheader("Final Result")
    st.success(f"Estimated **AI-generated content**: **{st.session_state.ai_percentage:.2f}%**")