Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
# -----------------
# Directory containing this app file, and the HF account owning the repos used below
# -----------------
APP_DIR = Path(__file__).parent.resolve()
account_name = 'mamba413'

# -----------------
# Fix Streamlit Permission Issues
# -----------------
# On a Hugging Face Space (SPACE_ID is set), point Streamlit settings and the
# Hugging Face caches at writable locations under /tmp.
if os.environ.get('SPACE_ID'):
    # NOTE(review): Streamlit reads its server config when `streamlit run`
    # starts; exporting these from inside the script may not affect the
    # already-running server. Confirm, or move them to the Space's env vars /
    # .streamlit/config.toml.
    os.environ.update({
        'STREAMLIT_SERVER_FILE_WATCHER_TYPE': 'none',
        'STREAMLIT_BROWSER_GATHER_USAGE_STATS': 'false',
        'STREAMLIT_SERVER_ENABLE_CORS': 'false',
    })

    # Send every Hugging Face cache to one writable directory.
    CACHE_DIR = '/tmp/huggingface_cache'
    os.makedirs(CACHE_DIR, exist_ok=True)
    for _hf_cache_var in (
        'HF_HOME',
        'TRANSFORMERS_CACHE',
        'HF_DATASETS_CACHE',
        'HUGGINGFACE_HUB_CACHE',
    ):
        os.environ[_hf_cache_var] = CACHE_DIR

    # Writable Streamlit config directory. It is created but not currently
    # exported: the STREAMLIT_HOME assignment below is disabled.
    streamlit_dir = Path('/tmp/.streamlit')
    streamlit_dir.mkdir(exist_ok=True, parents=True)
    # os.environ['STREAMLIT_HOME'] = '/tmp/.streamlit'
| import streamlit as st | |
| from FineTune.model import ComputeStat | |
| import time | |
# -----------------
# Page Configuration
# -----------------
# Must precede every other Streamlit call: Streamlit documents set_page_config
# as the first Streamlit command on a page, and versions enforcing that rule
# raise StreamlitAPIException if the st.markdown calls below run first.
st.set_page_config(
    page_title="DetectGPTPro",
    page_icon="🕵️",
)

# -----------------
# Global CSS
# -----------------
# Light background/border for the text area, text inputs and the selectbox.
st.markdown(
    """
<style>
/* Text area & text input */
textarea, input[type="text"] {
    background-color: #f8fafc !important;
    border: 1px solid #e5e7eb !important;
    color: #111827 !important;
}
textarea::placeholder {
    color: #9ca3af !important;
}
/* Selectbox */
div[data-testid="stSelectbox"] > div {
    background-color: #f8fafc !important;
    border: 1px solid #e5e7eb !important;
}
</style>
""",
    unsafe_allow_html=True,
)

# Orange styling for the primary ("Detect") button and its icon.
st.markdown(
    """
<style>
/* Detect button */
div.stButton > button[kind="primary"] {
    background-color: #fdae6b;
    border: white;
    color: black;
    font-weight: 600;
    height: 4.3rem;
    font-size: 1.1rem;
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 0.55rem;
}
/* Icon inside Detect button */
div.stButton > button[kind="primary"] span {
    font-size: 1.25rem;
    line-height: 1;
}
div.stButton > button[kind="primary"]:hover {
    background-color: #fd8d3c;
    border-color: white;
}
div.stButton > button[kind="primary"]:active {
    background-color: #fd8d3c;
    border-color: white;
}
</style>
""",
    unsafe_allow_html=True,
)
# -----------------
# Model Loading (Cached)
# -----------------
@st.cache_resource(show_spinner="🔄 Loading model... This may take a moment on first launch.")
def load_model(from_pretrained, base_model, cache_dir, device):
    """
    Load the detector model once and reuse it across reruns and sessions.

    Streamlit re-executes the whole script on every user interaction, so the
    ``st.cache_resource`` decorator is what makes "load once" true. The body
    only runs for a new combination of arguments; otherwise the cached model
    object is returned. The decorator also shows the loading spinner while the
    body runs, covering both a Hub download and model construction.

    Args:
        from_pretrained: Local checkpoint directory, or a Hugging Face Hub repo
            id of the form ``"username/repo-name"``, which is downloaded first.
        base_model: Base model name passed to ``ComputeStat.from_pretrained``.
        cache_dir: Cache directory for the Hub download and model loading.
        device: Device string passed to ``ComputeStat.from_pretrained``.

    Returns:
        The loaded ``ComputeStat`` model with criterion function ``'mean'``.

    Raises:
        Exception: Errors from ``snapshot_download`` (after being printed) or
            from ``ComputeStat.from_pretrained`` propagate to the caller.
    """
    # Space-specific overrides are currently switched off; the original
    # detection is kept for reference.
    # is_hf_space = os.environ.get('SPACE_ID') is not None
    is_hf_space = False
    if is_hf_space:
        cache_dir = '/tmp/huggingface_cache'
        os.makedirs(cache_dir, exist_ok=True)
        device = 'cpu'
        print("Using **CPU** now!")

    # Optional HF token, used to access gated / private models.
    hf_token = os.environ.get('HF_TOKEN', None)
    if hf_token:
        # Also log in, so later Hub calls can pick up the token.
        try:
            from huggingface_hub import login
            login(token=hf_token)
            print("✅ Successfully authenticated with HF token")
        except Exception as e:
            # Best effort: the token is still passed explicitly to
            # snapshot_download below.
            print(f"⚠️ HF login warning: {e}")

    # A Hub repo id looks like "username/repo-name". Anything that already
    # exists on disk, or is written as an explicit relative ("./", "../"),
    # absolute ("/") or home ("~") path, is treated as a local checkpoint.
    looks_like_local_path = (
        from_pretrained.startswith(('.', '/', '~'))
        or os.path.exists(from_pretrained)
    )
    is_hf_hub = '/' in from_pretrained and not looks_like_local_path
    if is_hf_hub:
        from huggingface_hub import snapshot_download
        print(f"📥 Downloading model from HuggingFace Hub: {from_pretrained}")
        try:
            # Download the whole repository snapshot into cache_dir.
            local_model_path = snapshot_download(
                repo_id=from_pretrained,
                cache_dir=cache_dir,
                token=hf_token,
                repo_type="model",
            )
        except Exception as e:
            print(f"❌ Failed to download model: {e}")
            raise
        print(f"✅ Model downloaded to: {local_model_path}")
        # Load from the downloaded local snapshot from here on.
        from_pretrained = local_model_path

    model = ComputeStat.from_pretrained(
        from_pretrained,
        base_model,
        device=device,
        cache_dir=cache_dir,
    )
    model.set_criterion_fn('mean')
    return model
# -----------------
# Result Feedback Module Import
# -----------------
from feedback import FeedbackManager

# Feedback records go to a Hugging Face dataset repo; FEEDBACK_DATASET_ID
# overrides the default repo id. HF_TOKEN must be set to reach a private repo.
FEEDBACK_DATASET_ID = os.environ.get(
    'FEEDBACK_DATASET_ID',
    f'{account_name}/user-feedback',
)

_running_on_space = bool(os.environ.get('SPACE_ID'))
feedback_manager = FeedbackManager(
    dataset_repo_id=FEEDBACK_DATASET_ID,
    hf_token=os.environ.get('HF_TOKEN'),
    # Keep a local backup copy only when running outside a Space.
    local_backup=not _running_on_space,
)
# -----------------
# Configuration
# -----------------
# On a Space the model runs on CPU; local runs default to 'mps' (presumably
# PyTorch's Apple Metal backend — switch to e.g. 'cuda' on other hardware).
MODEL_CONFIG = dict(
    from_pretrained='./src/FineTune/ckpt/',
    base_model='gemma-1b',
    cache_dir='../cache',
    device='mps' if not os.environ.get('SPACE_ID') else 'cpu',
)

# Text domains offered in the UI; the first entry is the default selection.
DOMAINS = [
    "General",
    "Academia", "Finance", "Government", "Knowledge",
    "Legislation", "Medicine", "News", "UserReview",
]
# Load the model at startup. A failure is recorded instead of raised so that
# the page can report it (see the model_loaded check further down).
model_loaded = False
error_message = None
try:
    model = load_model(
        from_pretrained=MODEL_CONFIG['from_pretrained'],
        base_model=MODEL_CONFIG['base_model'],
        cache_dir=MODEL_CONFIG['cache_dir'],
        device=MODEL_CONFIG['device'],
    )
except Exception as load_error:
    error_message = str(load_error)
else:
    model_loaded = True
# Per-session defaults, set only the first time a browser session runs the script.
_SESSION_DEFAULTS = {
    'last_detection': None,   # dict describing the most recent detection
    'feedback_given': False,  # whether feedback was recorded for it
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# -----------------
# Streamlit Layout
# -----------------
_page_title_html = "<h1 style='text-align: center;'> Detect AI-Generated Texts 🕵️ </h1>"
st.markdown(_page_title_html, unsafe_allow_html=True)

# Nothing below works without a model: report the load error and end this run.
if not model_loaded:
    st.error(f"❌ Failed to load model: {error_message}")
    st.stop()
# -----------------
# Main Interface
# -----------------
# A full-width text area, then one row of three equal columns:
# domain selector | Detect button | significance-level slider.
text_input = st.text_area(
    label="📝 Input Text to be Detected",
    placeholder="Paste your text here",
    height=240,
    label_visibility="hidden",
)

subcol11, subcol12, subcol13 = st.columns((1, 1, 1))

with subcol11:
    selected_domain = st.selectbox(
        label="💡 Domain that matches your text",
        options=DOMAINS,
        index=0,  # first entry, "General"
    )

with subcol12:
    detect_clicked = st.button("🔍 Detect", type="primary", use_container_width=True)

with subcol13:
    selected_level = st.slider(
        label="Significance level (α)",
        min_value=0.01,
        max_value=0.2,
        value=0.05,
        step=0.005,
    )
# -----------------
# Detection Logic
# -----------------
# Streamlit reruns this whole script on every interaction, and st.button()
# returns True only during the one rerun triggered by its own click. The
# detection result is therefore kept in st.session_state and rendered on every
# rerun. If the results and feedback buttons were drawn only inside
# `if detect_clicked:`, a click on a feedback button would start a rerun in
# which `detect_clicked` is False, and the feedback would never be saved.

# Markdown shown in the "Interpretation and Suggestions" expander.
_INTERPRETATION_MD = """
+ Interpretation:
    - $p$-value: Lower $p$-value (closer to 0) indicates text is **more likely AI-generated**; Higher $p$-value (closer to 1) indicates text is **more likely human-written**.
    - Significance Level (α): a threshold set by the user to determine the sensitivity of the detection. Lower α means stricter criteria for claiming the text is AI-generated.
+ Suggestions for better detection:
    - Provide longer text inputs for more reliable detection results.
    - Select the domain that best matches the content of your text to improve detection accuracy.
"""

# CSS tightening the spacing inside the result expander.
_EXPANDER_CSS = """
<style>
/* Tighten spacing inside Clarification / Citation expanders */
div[data-testid="stExpander"] {
    margin-top: -1.3rem;
}
div[data-testid="stExpander"] p,
div[data-testid="stExpander"] li {
    line-height: 1.35;
    margin-bottom: 0.1rem;
}
div[data-testid="stExpander"] ul {
    margin-top: 0.1rem;
}
</style>
"""


def _run_detection(text, domain):
    """Score `text` for `domain` and store the outcome in session state.

    On success ``st.session_state.last_detection`` becomes a dict with keys
    ``text``, ``domain``, ``statistics``, ``p_value`` and ``elapsed_time``.
    On failure the previous result is cleared and the error is shown on the
    page. The feedback flag is reset in both cases.
    """
    st.session_state.feedback_given = False
    start_time = time.time()
    # Placeholder so the spinner can be removed once inference finishes.
    status_placeholder = st.empty()
    try:
        with status_placeholder:
            with st.spinner(f"🔍 Analyzing text in {domain} domain..."):
                crit, p_value = model.compute_p_value(text, domain)
        elapsed_time = time.time() - start_time
        # Convert tensors to Python scalars if needed
        if hasattr(crit, 'item'):
            crit = crit.item()
        if hasattr(p_value, 'item'):
            p_value = p_value.item()
    except Exception as e:
        status_placeholder.empty()
        st.session_state.last_detection = None
        st.error(f"❌ Error during detection: {str(e)}")
        st.exception(e)
        return
    status_placeholder.empty()
    st.session_state.last_detection = {
        'text': text,
        'domain': domain,
        'statistics': crit,
        'p_value': p_value,
        'elapsed_time': elapsed_time,
    }


def _render_detection(detection, level):
    """Show the timing, the test conclusion and the interpretation help.

    `level` is the significance level α currently selected on the slider.
    Moving the slider after a detection re-evaluates the conclusion for the
    same p-value without re-running the model.
    """
    p_value = detection['p_value']
    st.caption(f"⏱️ Processing time: {detection['elapsed_time']:.2f} seconds")
    # Reject "human-written" only when p < α; p == α does not reject.
    rejected = p_value < level
    conclusion = (
        'Text is likely LLM-generated.' if rejected
        else 'Fail to reject hypothesis that text is human-written.'
    )
    comparison = 'less than' if rejected else 'not less than'
    # α is printed with ":g" because the slider step is 0.005: ".2f" would
    # mangle levels such as 0.025 or 0.075.
    st.info(
        "**Conclusion**:\n"
        f"{conclusion}\n"
        f"based on the observation that $p$-value {p_value:.3f} is {comparison} "
        f"significance level {level:g} 📊",
        icon="💡",
    )
    st.markdown(_EXPANDER_CSS, unsafe_allow_html=True)
    with st.expander("📋 Interpretation and Suggestions"):
        st.markdown(_INTERPRETATION_MD)


def _submit_feedback(detection, label):
    """Save feedback `label` ('expected' / 'unexpected') for `detection`.

    Returns ``(True, message)`` when the feedback manager reports success, and
    ``(False, None)`` otherwise. Failures are reported on the page. A success
    sets ``st.session_state.feedback_given``.
    """
    try:
        success, message = feedback_manager.save_feedback(
            detection['text'],
            detection['domain'],
            detection['statistics'],
            detection['p_value'],
            label,
        )
    except Exception as e:
        import traceback
        st.error(f"Failed to save feedback: {str(e)}")
        st.code(traceback.format_exc())
        return False, None
    if not success:
        st.error(f"Failed to save feedback: {message}")
        return False, None
    st.session_state.feedback_given = True
    return True, message


def _render_feedback(detection):
    """Draw the Expected / Unexpected buttons and handle a click on either.

    Once feedback has been recorded for the current detection, the buttons are
    disabled until the next detection resets the flag. This prevents duplicate
    records from repeated clicks.
    """
    st.markdown("**📝 Result Feedback**: Does this detection result meet your expectations?")
    already_given = st.session_state.feedback_given
    feedback_col1, feedback_col2 = st.columns(2)
    with feedback_col1:
        if st.button("✅ Expected", use_container_width=True, type="secondary",
                     key="feedback_expected_btn", disabled=already_given):
            ok, message = _submit_feedback(detection, 'expected')
            if ok:
                st.success("✅ Thank you for your feedback!")
                st.caption(f"💾 {message}")
    with feedback_col2:
        if st.button("❌ Unexpected", use_container_width=True, type="secondary",
                     key="feedback_unexpected_btn", disabled=already_given):
            ok, message = _submit_feedback(detection, 'unexpected')
            if ok:
                st.warning("❌ Feedback recorded! This will help us improve.")
                st.caption(f"💾 {message}")
    # On later reruns (e.g. moving the slider) keep a confirmation visible.
    if already_given:
        st.success("✅ Feedback submitted successfully!")


if detect_clicked:
    if not text_input.strip():
        st.warning("⚠️ Please enter some text before detecting.")
        # Do not keep showing an earlier result next to this warning.
        st.session_state.last_detection = None
    else:
        _run_detection(text_input, selected_domain)

if st.session_state.last_detection is not None:
    _render_detection(st.session_state.last_detection, selected_level)
    _render_feedback(st.session_state.last_detection)
| # with st.expander("📋 Citation"): | |
| # st.markdown( | |
| # """ | |
| # If you find this tool useful for you, please cite our paper: **[AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees](https://arxiv.org/abs/2510.01268)** | |
| # """ | |
| # ) | |
| # st.code( | |
| # """ | |
| # @inproceedings{zhou2024adadetectgpt, | |
| # title={AdaDetectGPT: Adaptive Detection of LLM-Generated Text with Statistical Guarantees}, | |
| # author={Hongyi Zhou and Jin Zhu and Pingfan Su and Kai Ye and Ying Yang and Shakeel A O B Gavioli-Akilagun and Chengchun Shi}, | |
| # booktitle={The Thirty-Ninth Annual Conference on Neural Information Processing Systems}, | |
| # year={2025}, | |
| # } | |
| # """, | |
| # language="bibtex" | |
| # ) | |
# -----------------
# Footer
# -----------------
# A disclaimer pinned to the bottom of the viewport (position: fixed). The main
# block container gets enough bottom padding for the last widgets to scroll
# clear of the footer; the previous 1px value *reduced* Streamlit's default
# padding and let the footer cover them. `.main .block-container` is kept for
# older Streamlit releases, and `stMainBlockContainer` is added for newer ones.
st.markdown(
    """
<style>
.footer {
    position: fixed;
    left: 0;
    bottom: 0;
    width: 100%;
    background-color: white;
    color: gray;
    text-align: center;
    padding: 1px;
    border-top: 1px solid #e0e0e0;
    z-index: 999;
}
/* Add padding to main content to prevent overlap with fixed footer */
.main .block-container,
[data-testid="stMainBlockContainer"] {
    padding-bottom: 4rem;
}
</style>
<div class='footer'>
<small> This tool is developed for research purposes only. The detection results are not 100% accurate and should not be used as the sole basis for any critical decisions. Users are advised to use this tool responsibly and ethically. </small>
</div>
""",
    unsafe_allow_html=True,
)