import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import string
import logging
# --- Set Page Config FIRST ---
# This MUST be the first Streamlit command
st.set_page_config(page_title="AI Text Detector", layout="wide")
# --- Configuration ---
MODEL_NAME = "muyiiwaa/modernbert_ai_human_text"
MODEL_MAX_LENGTH = 512
# Configure basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Model Loading (Cached for HF Spaces) ---
@st.cache_resource  # Caches the model/tokenizer across user sessions on Spaces
def load_resources():
    """Loads the transformer model and tokenizer."""
    logger.info(f"Attempting to load model: {MODEL_NAME}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")
        model.to(device)
        model.eval()
        logger.info("Model and tokenizer loaded successfully.")
        return tokenizer, model, device
    except Exception as e:
        logger.error(f"Error loading model/tokenizer: {e}", exc_info=True)
        # This st.error() call is why set_page_config must come first
        st.error(f"Error loading model '{MODEL_NAME}'. Please check logs or model availability. Error: {e}")
        return None, None, None
# Attempt to load model/tokenizer at app startup
tokenizer, model, device = load_resources()
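# The class-index mapping used in predict() below (index 0 = human, index 1 = AI)
# is an assumption about this checkpoint; logging the configured id2label at
# startup makes it easy to verify against the model card.
if model is not None:
    logger.info(f"Configured label mapping: {model.config.id2label}")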
# --- Prediction Function ---
def predict(text: str):
    """Classifies `text` and returns class probabilities, or an error dict."""
    if model is None or tokenizer is None or device is None:
        return None  # Indicate failure if resources aren't loaded
    try:
        # Basic preprocessing: strip punctuation and surrounding whitespace
        processed_text = text.translate(str.maketrans('', '', string.punctuation)).strip()
        if not processed_text:
            return {"error": "Input text becomes empty after removing punctuation."}
        inputs = tokenizer(
            processed_text,
            padding="max_length",
            truncation=True,
            max_length=MODEL_MAX_LENGTH,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = F.softmax(outputs.logits, dim=1).squeeze()
        probs_cpu = probabilities.cpu().numpy()
        predicted_class_id = probs_cpu.argmax().item()
        predicted_label = "AI-generated" if predicted_class_id == 1 else "Human-written"
        return {
            "human_prob": float(probs_cpu[0]),
            "ai_prob": float(probs_cpu[1]),
            "prediction": predicted_label
        }
    except Exception as e:
        logger.error(f"Error during prediction: {e}", exc_info=True)
        # Return error info instead of calling st.error directly here
        return {"error": f"Analysis failed: {e}"}
# --- Streamlit App UI ---
st.header("AI Text Detector Demo")
st.caption(f"Using model: `{MODEL_NAME}`")
input_text = st.text_area("Enter text to analyze:", height=150, placeholder="Paste or type text here...")
if st.button("Analyze", type="primary", disabled=(model is None)):
    if not input_text.strip():
        st.warning("Please enter some text.")
    elif model is not None:  # Check again that the model is loaded before predicting
        with st.spinner("Analyzing..."):
            result = predict(input_text)
        st.markdown("---")  # Separator
        if result and "error" not in result:
            st.subheader("Result")
            pred_label = result['prediction']
            ai_prob = result['ai_prob']
            if pred_label == "AI-generated":
                st.error(f"**Prediction: {pred_label}** (Confidence: {ai_prob:.1%})")
            else:
                st.success(f"**Prediction: {pred_label}** (Human Confidence: {result['human_prob']:.1%})")
            # Simple progress bars for the class probabilities
            st.progress(result['human_prob'], text=f"Human Probability: {result['human_prob']:.1%}")
            st.progress(result['ai_prob'], text=f"AI Probability: {result['ai_prob']:.1%}")
        elif result and "error" in result:
            st.error(f"Analysis Error: {result['error']}")  # Display error message from predict()
        else:
            # Handles the case where predict() returns None unexpectedly
            st.error("Analysis failed for an unknown reason.")
    else:
        st.error("Model not loaded. Cannot analyze.")