Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Flask Web Application for Article Summarizer with TTS | |
| Enhanced with caching, performance optimizations, and better error handling | |
| """ | |
| from flask import Flask, render_template, request, jsonify | |
| import os | |
| import time | |
| import threading | |
| import logging | |
| from datetime import datetime | |
| import re | |
| from pathlib import Path | |
| import hashlib | |
| import json | |
| from functools import lru_cache | |
| import gc | |
| import torch | |
| import trafilatura | |
| import soundfile as sf | |
| import requests | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from kokoro import KPipeline | |
# ---------------- Logging ----------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# ---------------- Flask ----------------
app = Flask(__name__)
# NOTE(review): falls back to the literal "change-me" when SECRET_KEY is not
# set in the environment — confirm deployments always provide a real secret.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "change-me")
# ---------------- Caching & Performance ----------------
# In-memory caches. Each value is a tuple whose last element is the time.time()
# at which the entry was stored; the code in this file touches them only while
# holding _cache_lock.
_summary_cache = {}  # text hash -> (summary, stored_at)
_audio_cache = {}  # hash of summary + voice -> (audio filename, duration seconds, stored_at)
_scrape_cache = {}  # hash of URL -> (scraped content, stored_at)
_cache_lock = threading.Lock()
# Cache settings
MAX_CACHE_SIZE = 100  # entries kept per cache after cleanup
CACHE_EXPIRY_HOURS = 24  # entries older than this count as expired
| def _get_cache_key(content: str) -> str: | |
| """Generate a cache key from content.""" | |
| return hashlib.md5(content.encode('utf-8')).hexdigest() | |
| def _is_cache_expired(timestamp: float) -> bool: | |
| """Check if cache entry is expired.""" | |
| return time.time() - timestamp > (CACHE_EXPIRY_HOURS * 3600) | |
| def _cleanup_cache(cache_dict: dict): | |
| """Remove expired entries and maintain size limit.""" | |
| current_time = time.time() | |
| # Remove expired entries | |
| expired_keys = [ | |
| key for key, (_, timestamp) in cache_dict.items() | |
| if _is_cache_expired(timestamp) | |
| ] | |
| for key in expired_keys: | |
| cache_dict.pop(key, None) | |
| # Maintain size limit (LRU-style) | |
| if len(cache_dict) > MAX_CACHE_SIZE: | |
| # Sort by timestamp and remove oldest | |
| sorted_items = sorted(cache_dict.items(), key=lambda x: x[1][1]) | |
| items_to_remove = len(cache_dict) - MAX_CACHE_SIZE | |
| for key, _ in sorted_items[:items_to_remove]: | |
| cache_dict.pop(key, None) | |
| def _get_text_hash(text: str) -> str: | |
| """Cached text hashing for performance.""" | |
| return hashlib.sha256(text.encode('utf-8')).hexdigest()[:16] | |
# ---------------- Globals ----------------
# Model handles; None until load_models() assigns them.
qwen_model = None
qwen_tokenizer = None
kokoro_pipeline = None
# Load state returned to clients: "loaded" becomes True after a successful
# load; "error" holds "<ExceptionType>: <message>" after a failed one.
model_loading_status = {"loaded": False, "error": None}
_load_lock = threading.Lock()  # held for the whole of load_models()
_loaded_once = False  # idempotence guard across threads
# Voice whitelist
ALLOWED_VOICES = {
    "af_heart", "af_bella", "af_nicole", "am_michael",
    "am_fenrir", "af_sarah", "bf_emma", "bm_george"
}
# HTTP headers to look like a real browser for sites that block bots
BROWSER_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 13_5) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
    "Accept-Language": "en-US,en;q=0.9",
}
# Create output dirs (robust, relative to this file)
BASE_DIR = Path(__file__).parent.resolve()
STATIC_DIR = BASE_DIR / "static"
AUDIO_DIR = STATIC_DIR / "audio"  # generated .wav files are written here
SUMM_DIR = STATIC_DIR / "summaries"
for p in (AUDIO_DIR, SUMM_DIR):
    try:
        p.mkdir(parents=True, exist_ok=True)
    except PermissionError:
        # Only PermissionError is tolerated; any other failure propagates at import time.
        logger.warning("No permission to create %s (will rely on image pre-created dirs).", p)
| # ---------------- Helpers ---------------- | |
def _get_device():
    """Return the device of the Qwen model's first parameter.

    Reading it from a parameter works for both CPU and GPU placement and is
    safer than relying on ``qwen_model.device``.
    """
    first_parameter = next(qwen_model.parameters())
    return first_parameter.device
| def _safe_trim_to_tokens(text: str, tokenizer, max_tokens: int) -> str: | |
| ids = tokenizer.encode(text, add_special_tokens=False) | |
| if len(ids) <= max_tokens: | |
| return text | |
| ids = ids[:max_tokens] | |
| return tokenizer.decode(ids, skip_special_tokens=True) | |
| # Remove any leaked <think>…</think> (with optional attributes) or similar tags | |
| _THINK_BLOCK_RE = re.compile( | |
| r"<\s*(think|reasoning|thought)\b[^>]*>.*?<\s*/\s*\1\s*>", | |
| re.IGNORECASE | re.DOTALL, | |
| ) | |
| _THINK_TAGS_RE = re.compile(r"</?\s*(think|reasoning|thought)\b[^>]*>", re.IGNORECASE) | |
| def _strip_reasoning(text: str) -> str: | |
| cleaned = _THINK_BLOCK_RE.sub("", text) # remove full blocks | |
| cleaned = _THINK_TAGS_RE.sub("", cleaned) # remove any stray tags | |
| cleaned = re.sub(r"```(?:\w+)?\s*```", "", cleaned) # collapse empty fenced blocks | |
| return cleaned.strip() | |
| def _normalize_url_for_proxy(u: str) -> str: | |
| # r.jina.ai expects 'http://<host>/<path>' after it; unify scheme-less | |
| u2 = u.replace("https://", "").replace("http://", "") | |
| return f"https://r.jina.ai/http://{u2}" | |
| def _maybe_extract_from_html(pasted: str) -> str: | |
| """If the pasted text looks like HTML, try to extract the main text via trafilatura.""" | |
| looks_html = bool(re.search(r"</?(html|div|p|article|section|span|body|h1|h2)\b", pasted, re.I)) | |
| if not looks_html: | |
| return pasted | |
| try: | |
| extracted = trafilatura.extract(pasted, include_comments=False, include_tables=False) or "" | |
| return extracted.strip() or pasted | |
| except Exception: | |
| return pasted | |
| # ---------------- Model Load ---------------- | |
def load_models():
    """Load the Qwen model/tokenizer and the Kokoro pipeline once.

    The whole load runs under ``_load_lock``. After a successful load
    ``_loaded_once`` is set and later calls return immediately. On failure the
    error is recorded in ``model_loading_status`` and logged, and
    ``_loaded_once`` stays False, so a later call attempts the load again.
    """
    global qwen_model, qwen_tokenizer, kokoro_pipeline, model_loading_status, _loaded_once
    with _load_lock:
        if _loaded_once:
            return
        try:
            logger.info("Loading Qwen3-0.6B…")
            checkpoint = "Qwen/Qwen3-0.6B"
            qwen_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
            qwen_model = AutoModelForCausalLM.from_pretrained(
                checkpoint,
                torch_dtype="auto",
                device_map="auto",  # CPU or GPU automatically
            )
            qwen_model.eval()  # inference mode

            logger.info("Loading Kokoro TTS…")
            kokoro_pipeline = KPipeline(lang_code="a")

            model_loading_status.update(loaded=True, error=None)
            _loaded_once = True
            logger.info("✅ Models ready")
        except Exception as exc:
            failure = f"{type(exc).__name__}: {exc}"
            model_loading_status.update(loaded=False, error=failure)
            logger.exception("Failed to load models: %s", failure)
| # ---------------- Enhanced Core Logic with Caching ---------------- | |
def scrape_article_text(url: str) -> tuple[str | None, str | None]:
    """
    Try to fetch & extract article text with caching.

    Strategy:
    0) Return a non-expired entry from ``_scrape_cache`` if there is one
    1) ``trafilatura.fetch_url`` + ``trafilatura.extract``
    2) ``requests.get`` with browser-like headers + ``trafilatura.extract``
    3) (optional) r.jina.ai proxy fallback, only when the environment
       variable ALLOW_PROXY_FALLBACK is exactly "1"

    Each later step runs only if the earlier ones produced no text. A
    successful result is stored in the cache before it is returned.

    Args:
        url: Address of the article to fetch.

    Returns:
        ``(content, None)`` on success, or ``(None, error_message)`` when no
        step yields text or an unexpected exception is raised.
    """
    # NOTE(review): `url` is fetched as given, with no scheme/host checks in
    # this function. If it comes straight from a client request, confirm that
    # server-side fetching of arbitrary addresses is acceptable here.
    # Check cache first
    cache_key = _get_cache_key(url)
    with _cache_lock:
        if cache_key in _scrape_cache:
            content, timestamp = _scrape_cache[cache_key]
            if not _is_cache_expired(timestamp):
                logger.info(f"Cache hit for URL: {url[:50]}...")
                return content, None
            else:
                # Remove expired entry
                _scrape_cache.pop(cache_key, None)
    try:
        content = None
        # --- 1) Direct fetch via Trafilatura ---
        downloaded = trafilatura.fetch_url(url)
        if downloaded:
            text = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
            if text:
                content = text
        # --- 2) Raw requests + Trafilatura extract ---
        if not content:
            try:
                r = requests.get(url, headers=BROWSER_HEADERS, timeout=15)
                if r.status_code == 200 and r.text:
                    text = trafilatura.extract(r.text, include_comments=False, include_tables=False, url=url)
                    if text:
                        content = text
                elif r.status_code == 403:
                    logger.info("Site returned 403; considering proxy fallback (if enabled).")
            except requests.RequestException as e:
                # A failed request is not fatal: the proxy step may still succeed.
                logger.info("requests.get failed: %s", e)
        # --- 3) Optional proxy fallback (off by default) ---
        if not content and os.environ.get("ALLOW_PROXY_FALLBACK", "0") == "1":
            proxy_url = _normalize_url_for_proxy(url)
            try:
                pr = requests.get(proxy_url, headers=BROWSER_HEADERS, timeout=15)
                if pr.status_code == 200 and pr.text:
                    # If extraction yields nothing, the raw proxy response text is used as-is.
                    extracted = trafilatura.extract(pr.text, include_comments=False, include_tables=False) or pr.text
                    if extracted and extracted.strip():
                        content = extracted.strip()
            except requests.RequestException as e:
                logger.info("Proxy fallback failed: %s", e)
        if content:
            # Cache the successful result
            with _cache_lock:
                _scrape_cache[cache_key] = (content, time.time())
                _cleanup_cache(_scrape_cache)
            return content, None
        return None, (
            "Failed to download the article content (site may block automated fetches). "
            "Try another URL, paste the text manually, or set ALLOW_PROXY_FALLBACK=1."
        )
    except Exception as e:
        # Anything not handled above is reported to the caller as an error string.
        return None, f"Error scraping article: {e}"
def summarize_with_qwen(text: str) -> tuple[str | None, str | None]:
    """Generate a summary of *text* with the Qwen chat model, using a cache.

    The cache key is a hash of the full input text. On a miss the article is
    trimmed to a token budget, wrapped in a system + user chat prompt,
    generated with sampling, stripped of leaked reasoning markup and stored in
    ``_summary_cache``.

    Args:
        text: Article text to summarize.

    Returns:
        ``(summary, None)`` on success, or
        ``(None, "Error generating summary: ...")`` if anything inside the
        generation ``try`` block raises.
    """
    # Check cache first
    # NOTE(review): the hash is computed outside the try block below, so a
    # non-string `text` would raise here instead of returning an error tuple.
    cache_key = _get_text_hash(text)
    with _cache_lock:
        if cache_key in _summary_cache:
            summary, timestamp = _summary_cache[cache_key]
            if not _is_cache_expired(timestamp):
                logger.info(f"Cache hit for summary: {cache_key}")
                return summary, None
            else:
                # Remove expired entry
                _summary_cache.pop(cache_key, None)
    try:
        # Budget input tokens based on max context; fallback to 4096
        try:
            max_ctx = int(getattr(qwen_model.config, "max_position_embeddings", 4096))
        except Exception:
            max_ctx = 4096
        # Leave room for prompt + output tokens
        # (1024 tokens reserved, but the article budget never drops below 512)
        max_input_tokens = max(512, max_ctx - 1024)
        prompt_hdr = (
            "Please provide a concise and clear summary of the following article. "
            "Focus on the main points, key findings, and conclusions. "
            "Keep it easy to understand for someone who hasn't read the original.\n\nARTICLE:\n"
        )
        # Trim article to safe length
        article_trimmed = _safe_trim_to_tokens(text, qwen_tokenizer, max_input_tokens)
        user_content = prompt_hdr + article_trimmed
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant. Return ONLY the final summary as plain text. "
                    "Do not include analysis, steps, or <think> tags."
                ),
            },
            {"role": "user", "content": user_content},
        ]
        # Build the chat prompt text (disable thinking if supported)
        try:
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            # Retry without the enable_thinking keyword if the call rejected it.
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        device = _get_device()
        model_inputs = qwen_tokenizer([text_input], return_tensors="pt").to(device)
        # Performance optimization: use torch.no_grad() and clear cache
        with torch.no_grad():
            generated_ids = qwen_model.generate(
                **model_inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.8,
                top_k=20,
                do_sample=True,
                pad_token_id=qwen_tokenizer.eos_token_id,  # Avoid warnings
            )
        # Keep only the newly generated tokens: drop the prompt-length prefix.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        summary = qwen_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        summary = _strip_reasoning(summary)  # <-- remove any leaked <think>…</think>
        # Cache the result
        # NOTE(review): an empty summary is cached like any other — confirm that is intended.
        with _cache_lock:
            _summary_cache[cache_key] = (summary, time.time())
            _cleanup_cache(_summary_cache)
        # Memory cleanup
        del model_inputs, generated_ids, output_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return summary, None
    except Exception as e:
        return None, f"Error generating summary: {e}"
def generate_speech(summary: str, voice: str) -> tuple[str | None, str | None, float]:
    """Render *summary* to a WAV file with Kokoro, using a cache.

    Args:
        summary: Text to speak.
        voice: Kokoro voice id. Anything outside ``ALLOWED_VOICES`` is
            replaced with ``"af_heart"``.

    Returns:
        ``(filename, None, duration_seconds)`` on success, where *filename* is
        relative to ``AUDIO_DIR``; ``(None, error_message, 0.0)`` on failure.
    """
    sample_rate = 24000  # Hz; used both for the duration and for the WAV header

    if voice not in ALLOWED_VOICES:
        voice = "af_heart"

    # Check cache first
    cache_key = _get_text_hash(summary + voice)
    with _cache_lock:
        cached = _audio_cache.get(cache_key)
        if cached is not None:
            filename, duration, timestamp = cached
            if not _is_cache_expired(timestamp) and (AUDIO_DIR / filename).exists():
                logger.info("Cache hit for audio: %s", cache_key)
                return filename, None, duration
            # Expired, or the file was deleted: drop the stale entry.
            _audio_cache.pop(cache_key, None)

    try:
        audio_chunks = []
        total_samples = 0
        for item in kokoro_pipeline(summary, voice=voice):
            # Log only the item type: the item itself carries the audio data,
            # which is far too large to dump into the log.
            logger.debug("Kokoro yielded item of type %s", type(item).__name__)
            _, _, audio = item
            if audio is None:
                # Defensive: skip an item that carries no audio.
                continue
            audio_chunks.append(audio)
            total_samples += len(audio)

        if not audio_chunks:
            return None, "No audio generated.", 0.0

        total_duration = total_samples / sample_rate
        combined = audio_chunks[0] if len(audio_chunks) == 1 else torch.cat(audio_chunks, dim=0)

        # The cache key is part of the name so that two different requests
        # handled within the same second cannot overwrite each other's file.
        filename = f"summary_{int(time.time())}_{cache_key}.wav"
        sf.write(str(AUDIO_DIR / filename), combined.numpy(), sample_rate)
    except Exception as e:
        return None, f"Error generating speech: {e}", 0.0

    # The file is already on disk, so caching is best effort: a bookkeeping
    # failure is logged but must not turn a successful render into an error.
    try:
        with _cache_lock:
            _audio_cache[cache_key] = (filename, total_duration, time.time())
            _cleanup_cache(_audio_cache)
    except Exception:
        logger.exception("Audio cache update failed for %s", filename)

    return filename, None, total_duration
| # ---------------- Performance Monitoring ---------------- | |
def cleanup_old_files(max_age_seconds: float = 7 * 24 * 3600):
    """Delete generated audio files older than *max_age_seconds*.

    Only files matching ``summary_*.wav`` directly inside ``AUDIO_DIR`` are
    considered. Each file is handled on its own, so one file that cannot be
    inspected or removed does not stop the rest from being cleaned up.

    Args:
        max_age_seconds: Files whose modification time is further in the past
            than this are removed. Defaults to 7 days.

    Returns:
        None. Problems are logged, not raised.
    """
    cutoff = time.time() - max_age_seconds
    try:
        candidates = list(AUDIO_DIR.glob("summary_*.wav"))
    except OSError as e:
        logger.warning(f"Error during file cleanup: {e}")
        return

    for audio_file in candidates:
        try:
            if audio_file.stat().st_mtime < cutoff:
                audio_file.unlink()
                logger.info(f"Cleaned up old audio file: {audio_file.name}")
        except FileNotFoundError:
            # The file disappeared between listing and removal; nothing to do.
            continue
        except OSError as e:
            logger.warning("Could not clean up %s: %s", audio_file.name, e)
def get_cache_stats():
    """Return entry counts for the three caches plus a rough size estimate.

    ``memory_usage_mb`` is the total length of ``str(entry)`` over all cached
    values, divided by 1024 * 1024. It is an approximation, not a measured
    memory footprint.
    """
    with _cache_lock:
        total_chars = 0
        for cache in (_summary_cache, _audio_cache, _scrape_cache):
            for entry in cache.values():
                total_chars += len(str(entry))
        return {
            "summary_cache_size": len(_summary_cache),
            "audio_cache_size": len(_audio_cache),
            "scrape_cache_size": len(_scrape_cache),
            "memory_usage_mb": total_chars / (1024 * 1024),
        }
| # Schedule periodic cleanup | |
def periodic_cleanup():
    """Loop forever: sleep an hour, then run file cleanup and free memory.

    Never returns. Any exception raised during a pass is logged as a warning
    and the loop carries on with the next pass.
    """
    pause_seconds = 3600  # Run every hour
    while True:
        time.sleep(pause_seconds)
        try:
            cleanup_old_files()
            # Force garbage collection
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as exc:
            logger.warning(f"Error in periodic cleanup: {exc}")
# Start cleanup thread
# This runs at import time. daemon=True means the thread does not keep the
# process alive when the main thread exits.
cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
cleanup_thread.start()
| # ---------------- Routes ---------------- | |
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def index():
    """Render and return the ``index.html`` template."""
    page = render_template("index.html")
    return page
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def status():
    """Return the current ``model_loading_status`` dict as a JSON response."""
    current_status = model_loading_status
    return jsonify(current_status)
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def process_article():
    """Summarize pasted text or a URL and optionally synthesize audio.

    Reads a JSON body with the optional keys ``text``, ``url``,
    ``generate_audio`` and ``voice``. Pasted text wins over the URL when both
    are given.

    The body is validated before use: a JSON value that is not an object is
    treated as an empty payload, and a non-string ``text``/``url``/``voice``
    is treated as missing. Calling ``.get()`` or ``.strip()`` on such values
    would otherwise raise.

    Returns:
        A JSON response. Failures are reported as
        ``{"success": False, "error": ...}``. On success the payload holds the
        summary, length statistics and a timestamp, plus either
        ``audio_file``/``audio_duration`` or ``audio_error`` when audio was
        requested.
    """
    if not model_loading_status["loaded"]:
        return jsonify({"success": False, "error": "Models not loaded yet. Please wait."})

    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number) has no keys to read.
        data = {}

    def _text_field(name: str) -> str:
        """Return the stripped string stored under *name*, or "" if absent or not a string."""
        value = data.get(name)
        return value.strip() if isinstance(value, str) else ""

    # New: accept raw pasted text
    pasted_text = _text_field("text")
    url = _text_field("url")
    generate_audio = bool(data.get("generate_audio", False))
    voice = _text_field("voice") or "af_heart"

    if not pasted_text and not url:
        return jsonify({"success": False, "error": "Please paste text or provide a valid URL."})

    # 1) Resolve content: prefer pasted text if provided
    if pasted_text:
        article_content = _maybe_extract_from_html(pasted_text)
    else:
        article_content, scrape_error = scrape_article_text(url)
        if scrape_error:
            return jsonify({"success": False, "error": scrape_error})

    # 2) Summarize
    summary, summary_error = summarize_with_qwen(article_content)
    if summary_error:
        return jsonify({"success": False, "error": summary_error})

    # Compute both lengths once, None-safe, so all three statistics agree.
    article_length = len(article_content or "")
    summary_length = len(summary or "")
    resp = {
        "success": True,
        "summary": summary,
        "article_length": article_length,
        "summary_length": summary_length,
        "compression_ratio": round(summary_length / max(article_length, 1) * 100, 1),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # 3) TTS
    if generate_audio:
        try:
            audio_filename, audio_error, duration = generate_speech(summary, voice)
            if audio_error:
                resp["audio_error"] = audio_error
            else:
                resp["audio_file"] = f"/static/audio/{audio_filename}"
                resp["audio_duration"] = round(duration, 2)
        except Exception as e:
            # Audio is optional: report the failure but still return the summary.
            logger.exception("Error in audio generation: %s", e)
            resp["audio_error"] = f"Audio generation failed: {str(e)}"

    return jsonify(resp)
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def get_voices():
    """Return the selectable voices as a JSON list.

    Each element has the keys ``id``, ``name``, ``grade`` and ``description``.
    """
    fields = ("id", "name", "grade", "description")
    rows = (
        ("af_heart", "Female - Heart", "A", "❤️ Warm female voice (best quality)"),
        ("af_bella", "Female - Bella", "A-", "🔥 Energetic female voice"),
        ("af_nicole", "Female - Nicole", "B-", "🎧 Professional female voice"),
        ("am_michael", "Male - Michael", "C+", "Clear male voice"),
        ("am_fenrir", "Male - Fenrir", "C+", "Strong male voice"),
        ("af_sarah", "Female - Sarah", "C+", "Gentle female voice"),
        ("bf_emma", "British Female - Emma", "B-", "🇬🇧 British accent"),
        ("bm_george", "British Male - George", "C", "🇬🇧 British male voice"),
    )
    voices = [dict(zip(fields, row)) for row in rows]
    return jsonify(voices)
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def cache_stats():
    """Return cache sizes plus uptime and audio-file count as JSON.

    While the models are not loaded, only an ``error`` message is returned.
    Uptime is reported as 0 when ``app.start_time`` has not been set.
    """
    models_loaded = model_loading_status["loaded"]
    if not models_loaded:
        return jsonify({"error": "Models not loaded yet"})

    stats = get_cache_stats()
    stats["models_loaded"] = model_loading_status["loaded"]
    if hasattr(app, 'start_time'):
        stats["uptime_hours"] = round((time.time() - app.start_time) / 3600, 2)
    else:
        stats["uptime_hours"] = 0
    stats["cache_hit_rate"] = "Available after first requests"
    stats["total_audio_files"] = sum(1 for _ in AUDIO_DIR.glob("summary_*.wav"))
    return jsonify(stats)
# NOTE(review): no @app.route(...) decorator is visible above this handler in
# this view of the file — confirm it is registered with the app.
def health_check():
    """Return a small JSON health report.

    ``status`` is "healthy" once the models are loaded and "loading" before that.
    """
    models_loaded = model_loading_status["loaded"]
    if models_loaded:
        state = "healthy"
    else:
        state = "loading"
    report = {
        "status": state,
        "models_loaded": models_loaded,
        "timestamp": datetime.now().isoformat(),
        "version": "2.0.0-enhanced",
    }
    return jsonify(report)
# Kick off model loading when running under Gunicorn/containers
# Only happens when RUNNING_GUNICORN is exactly "1". The load runs in a
# background daemon thread, so importing this module does not wait for it.
if os.environ.get("RUNNING_GUNICORN", "0") == "1":
    threading.Thread(target=load_models, daemon=True).start()
# ---------------- Dev entrypoint ----------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="AI Article Summarizer Web App")
    parser.add_argument("--port", type=int, default=5001, help="Port to run the server on (default: 5001)")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
    args = parser.parse_args()

    # Track start time for uptime monitoring
    app.start_time = time.time()

    # Load models in background thread
    threading.Thread(target=load_models, daemon=True).start()

    # Respect platform env PORT when present (HF Spaces: 7860)
    port = int(os.environ.get("PORT", args.port))

    # Debug mode is opt-in. The interactive debugger lets anyone who can reach
    # the server execute code, and the default host binds every interface, so
    # it must not be on unless explicitly requested with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}

    print("🚀 Starting Enhanced Article Summarizer Web App v2.0…")
    print("📚 Models are loading in the background…")
    print(f"🌐 Open http://localhost:{port} in your browser")
    print("✨ New features:")
    print(" • Enhanced UI with animations and keyboard shortcuts")
    print(" • Smart caching for 10x faster repeat requests")
    print(" • Better error handling and performance monitoring")
    print(" • Accessibility improvements and mobile optimization")
    try:
        app.run(debug=debug_enabled, host=args.host, port=port)
    except OSError as e:
        if "Address already in use" in str(e):
            print(f"❌ Port {port} is already in use!")
            print("💡 Try a different port:")
            print(f" python app.py --port {port + 1}")
            print("📱 Or disable AirPlay Receiver in System Settings → General → AirDrop & Handoff")
        else:
            raise
# Set start time for production deployments too
# Applies when the module is imported rather than run as a script; the
# hasattr() guard leaves an already-set start_time untouched.
# NOTE(review): indentation was lost in this view; this is assumed to sit at
# module level, outside the __main__ block — confirm.
if not hasattr(app, 'start_time'):
    app.start_time = time.time()