Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import string | |
| import logging | |
# --- Set Page Config FIRST ---
# This MUST be the first Streamlit command executed by the script: every other
# st.* call below (including the st.error() inside load_resources()) has to
# come after it.
st.set_page_config(page_title="AI Text Detector", layout="wide")

# --- Configuration ---
# Identifier handed to from_pretrained() for both the tokenizer and the model.
MODEL_NAME = "muyiiwaa/modernbert_ai_human_text"
# Token length that predict() pads and truncates every input to.
MODEL_MAX_LENGTH = 512

# Configure basic logging: INFO level via basicConfig, plus a module-scoped
# logger used by load_resources() and predict().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Model Loading (Cached for HF Spaces) ---
# Streamlit re-executes the whole script on every widget interaction, so the
# expensive load lives in a helper wrapped with st.cache_resource. Without the
# decorator the model would be re-instantiated on each rerun.
@st.cache_resource(show_spinner="Loading model...")
def _load_model_bundle():
    """Load tokenizer and model once and share them across reruns/sessions.

    Returns:
        A ``(tokenizer, model, device)`` tuple with the model already moved to
        ``device`` and switched to eval mode.

    Raises:
        Exception: whatever ``from_pretrained`` / ``model.to`` raise. The error
        is deliberately allowed to propagate so that a failed load is not
        stored as a cached result and a later rerun can retry.
    """
    logger.info("Attempting to load model: %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)
    model.to(device)
    model.eval()  # inference only: disable dropout and similar train-time layers
    logger.info("Model and tokenizer loaded successfully.")
    return tokenizer, model, device


def load_resources():
    """Loads the transformer model and tokenizer.

    Returns:
        ``(tokenizer, model, device)`` on success, or ``(None, None, None)``
        when loading fails. On failure the error is logged with its traceback
        and also shown in the page via ``st.error``.
    """
    try:
        return _load_model_bundle()
    except Exception as e:  # top-level boundary: report the failure, keep the page alive
        logger.exception("Error loading model/tokenizer: %s", e)
        # This st.error() call is why set_page_config must come earlier
        st.error(f"Error loading model '{MODEL_NAME}'. Please check logs or model availability. Error: {e}")
        return None, None, None
# Attempt to load model/tokenizer at app startup.
# load_resources() returns (None, None, None) when loading fails; predict()
# and the Analyze button both check for that, so the page still renders.
tokenizer, model, device = load_resources()
# --- Prediction Function ---
def predict(text: str):
    """Classify *text* as human-written or AI-generated.

    Args:
        text: Raw user input. ASCII punctuation is stripped before tokenizing.

    Returns:
        * ``None`` when the model, tokenizer or device were not loaded.
        * ``{"error": <message>}`` when the input is empty after preprocessing
          or when inference raises.
        * ``{"human_prob": float, "ai_prob": float, "prediction": str}`` on
          success, where class index 0 is treated as human and index 1 as AI.
    """
    # Compare against None explicitly rather than relying on truthiness: the
    # failure sentinel from load_resources() is None, and an object's
    # truthiness can depend on __len__/__bool__.
    if model is None or tokenizer is None or device is None:
        return None  # Indicate failure if resources aren't loaded

    try:
        # Basic preprocessing: drop ASCII punctuation and surrounding whitespace.
        processed_text = text.translate(str.maketrans('', '', string.punctuation)).strip()
        if not processed_text:
            return {"error": "Input text becomes empty after removing punctuation."}

        inputs = tokenizer(
            processed_text,
            padding="max_length",
            truncation=True,
            max_length=MODEL_MAX_LENGTH,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # Exactly one sequence is submitted, so take batch row 0 explicitly.
        # A bare .squeeze() would also collapse any other size-1 dimension and
        # change what the [0]/[1] indexing below refers to.
        probs = F.softmax(outputs.logits, dim=1)[0].cpu().tolist()

        # Index of the first maximal probability.
        predicted_class_id = probs.index(max(probs))
        predicted_label = "AI-generated" if predicted_class_id == 1 else "Human-written"
        return {
            "human_prob": float(probs[0]),
            "ai_prob": float(probs[1]),
            "prediction": predicted_label,
        }
    except Exception as e:  # boundary: the UI layer decides how to display the failure
        logger.exception("Error during prediction: %s", e)
        # Return error info instead of calling st.error directly here
        return {"error": f"Analysis failed: {e}"}
# --- Streamlit App UI ---
# Note: st.set_page_config(...) is intentionally not called here; it runs at
# the very top of the script.
def _render_result(outcome):
    """Draw whatever predict() returned: result panel, error, or fallback."""
    if not outcome:
        # predict() handed back None (or nothing usable).
        st.error("Analysis failed for an unknown reason.")
        return

    if "error" in outcome:
        # Surface the message produced inside predict().
        st.error(f"Analysis Error: {outcome['error']}")
        return

    st.subheader("Result")
    label = outcome['prediction']
    ai_share = outcome['ai_prob']

    if label != "AI-generated":
        st.success(f"**Prediction: {label}** (Human Confidence: {outcome['human_prob']:.1%})")
    else:
        st.error(f"**Prediction: {label}** (Confidence: {ai_share:.1%})")

    # One bar per class probability.
    st.progress(outcome['human_prob'], text=f"Human Probability: {outcome['human_prob']:.1%}")
    st.progress(outcome['ai_prob'], text=f"AI Probability: {outcome['ai_prob']:.1%}")


st.header("AI Text Detector Demo")
st.caption(f"Using model: `{MODEL_NAME}`")

input_text = st.text_area(
    "Enter text to analyze:",
    height=150,
    placeholder="Paste or type text here...",
)

# The button is greyed out when the model failed to load.
if st.button("Analyze", type="primary", disabled=(model is None)):
    if not input_text.strip():
        st.warning("Please enter some text.")
    elif not model:
        # Second line of defence in case the button fired without a model.
        st.error("Model not loaded. Cannot analyze.")
    else:
        with st.spinner("Analyzing..."):
            result = predict(input_text)
        st.markdown("---")  # Separator
        _render_result(result)