from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator, ConfigDict
from contextlib import asynccontextmanager
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer
from difflib import SequenceMatcher
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import onnxruntime as ort
import uvicorn
import os
import time
import logging
from typing import List, Tuple
import re
import uuid
from datetime import datetime

os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global model variables
model = None
tokenizer = None
start_time = None  # Track server uptime

# Configuration
VERSION = "1.0.0"
MODEL_PATH = "IDK100boysaj/coedit-xl-onnx-8bit"
MAX_INPUT_TOKENS = 512   # CoEdit-XL hard limit (including the prompt)
MAX_OUTPUT_TOKENS = 512  # Fixed output length
OVERLAP_SENTENCES = 2    # Fixed 2-sentence overlap for adaptive merging
MAX_TEXT_LENGTH = 50000  # Character limit to prevent memory issues
BATCH_SIZE = 2           # Conservative batch size for 2 vCPU + 16 GB RAM

# Rate limiting configuration
limiter = Limiter(key_func=get_remote_address)
RATE_LIMIT = "100/minute"  # 100 requests per minute per IP
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup, clean up on shutdown"""
    global model, tokenizer, start_time
    start_time = datetime.now()
    logger.info(f"Starting CoEdit-XL API v{VERSION}...")
    logger.info(f"Loading CoEdit-XL quantized ONNX model from {MODEL_PATH}...")
    try:
        # Configure ONNX Runtime session options for a 2-vCPU host
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = 2  # Match vCPU count
        session_options.inter_op_num_threads = 1  # Single thread for inter-op
        session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL  # Sequential execution
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Disable thread spinning to save CPU cycles while waiting
        session_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
        session_options.add_session_config_entry("session.inter_op.allow_spinning", "0")
        logger.info("Configured ONNX session: 2 intra-op threads, no spinning, sequential execution")

        # Load the quantized ONNX model with Optimum
        # Explicitly specify file names for non-standard naming
        model = ORTModelForSeq2SeqLM.from_pretrained(
            MODEL_PATH,
            provider="CPUExecutionProvider",  # Force CPU execution
            session_options=session_options,  # Optimized session options
            encoder_file_name="encoder_model.onnx",
            decoder_file_name="decoder_model.onnx",
            decoder_with_past_file_name="decoder_with_past_model.onnx",
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_PATH,
        )

        logger.info("✓ Quantized ONNX model loaded successfully on CPU")
        logger.info(f"Optimizations: Sentence-based chunking | Fixed {OVERLAP_SENTENCES}-sentence overlap | Batch size: {BATCH_SIZE}")
        logger.info("Adaptive merging with 0.9 similarity threshold")
        logger.info(f"Max input tokens: {MAX_INPUT_TOKENS} | Max output tokens: {MAX_OUTPUT_TOKENS}")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise RuntimeError(f"Model loading failed: {str(e)}")

    yield  # Server runs here

    # Cleanup on shutdown
    logger.info("Shutting down...")
    model = None
    tokenizer = None
app = FastAPI(
    title="CoEdit-XL Grammar Correction API",
    version=VERSION,
    lifespan=lifespan,
    description="High-performance grammar correction API using quantized CoEdit-XL with rolling-window support",
)

# Rate limiting: attach the limiter and register the custom 429 handler
app.state.limiter = limiter


async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Custom rate-limit-exceeded handler"""
    return JSONResponse(
        status_code=429,
        content={
            "detail": "Rate limit exceeded. Maximum 100 requests per minute.",
            "retry_after": "60 seconds"
        }
    )

app.add_exception_handler(RateLimitExceeded, rate_limit_handler)

# CORS middleware for cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Restrict this for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class CorrectionRequest(BaseModel):
    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "text": "When I grow up, I start to understand what he said is quite right.",
                "prompt": "Fix grammatical errors in this sentence:"
            }
        }
    )

    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH, description="Text to correct")
    prompt: str = Field(
        "Fix grammatical errors in this sentence:",
        description="Task prompt (e.g., 'Fix grammatical errors:', 'Make this text more formal:', 'Paraphrase this:')"
    )

    @field_validator('text')
    @classmethod
    def text_not_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError('Text cannot be empty or only whitespace')
        return v

    @field_validator('prompt')
    @classmethod
    def prompt_not_empty(cls, v: str) -> str:
        if not v.strip():
            raise ValueError('Prompt cannot be empty or only whitespace')
        return v.strip()


class CorrectionResponse(BaseModel):
    original: str
    corrected: str
    num_chunks: int
    processing_time_seconds: float
    changes_made: bool
    model: str = "CoEdit-XL (Quantized ONNX)"
def split_into_complete_sentences(text: str) -> List[str]:
    """
    Split text into complete sentences with proper punctuation handling.

    Handles:
    - Sentence terminators: . ? ! (and standalone semicolons)
    - Abbreviations: Dr., Mr., Mrs., Ms., Ph.D., etc., i.e., e.g., vs., Inc., Ltd.
    - Semicolons in lists (does not split on these)
    - Preserves punctuation with each sentence

    Examples:
        "Dr. Smith likes apples; oranges; milk. He is happy!"
        → ["Dr. Smith likes apples; oranges; milk.", "He is happy!"]
        "I love it. They are great!"
        → ["I love it.", "They are great!"]
    """
    # Common abbreviations that end with a period (not sentence boundaries)
    abbreviations = [
        'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Sr', 'Jr',
        'Ph.D', 'M.D', 'B.A', 'M.A', 'Ph.D.', 'M.D.', 'B.A.', 'M.A.',
        'etc', 'vs', 'Inc', 'Ltd', 'Co', 'Corp',
        'i.e', 'e.g', 'a.m', 'p.m', 'U.S', 'U.K',
        'i.e.', 'e.g.', 'a.m.', 'p.m.', 'U.S.', 'U.K.'
    ]

    # Split on sentence boundaries: [.!?] followed by a space and a capital letter,
    # unless the terminating word is a known abbreviation
    sentences = []
    current = []
    words = text.split()
    for i, word in enumerate(words):
        current.append(word)
        # Check if this word ends with a sentence terminator
        if word and word[-1] in '.!?':
            # Check if it's an abbreviation
            is_abbrev = any(word.rstrip('.!?').endswith(abbr.rstrip('.')) for abbr in abbreviations)
            # If not an abbreviation and the next word starts with a capital, it's a sentence boundary
            if not is_abbrev and i + 1 < len(words) and words[i + 1] and words[i + 1][0].isupper():
                sentences.append(' '.join(current))
                current = []
    # Add remaining words
    if current:
        sentences.append(' '.join(current))

    # Handle semicolons that act as sentence separators (not in lists).
    # A semicolon is a separator if followed by a space and a capital letter.
    result = []
    for sent in sentences:
        # A sentence counts as a list when a colon precedes the semicolons
        # (e.g., "buy: apples; oranges"); such sentences are not split.
        if ';' in sent:
            colon_pos = sent.find(':')
            semicolon_positions = [i for i, c in enumerate(sent) if c == ';']
            # If semicolons come after a colon, it's likely a list - don't split
            is_list = colon_pos >= 0 and any(sp > colon_pos for sp in semicolon_positions)
            if not is_list:
                # Split on semicolons followed by a space and a capital letter
                subsents = re.split(r';\s+(?=[A-Z])', sent)
                # Add the semicolon back to each part except the last
                for i, ss in enumerate(subsents[:-1]):
                    result.append(ss + ';')
                if subsents[-1].strip():
                    result.append(subsents[-1])
            else:
                result.append(sent)
        else:
            result.append(sent)

    # Clean up and filter
    return [s.strip() for s in result if s.strip()]
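# Illustration (expected behavior of the splitter above, shown as comments so nothing
# runs at import time): semicolons after a colon are treated as list separators and
# left alone, while a bare semicolon followed by a capitalized word splits the sentence.
#
#   split_into_complete_sentences("We visited: Paris; Rome; Berlin. It was fun.")
#   → ["We visited: Paris; Rome; Berlin.", "It was fun."]
#
#   split_into_complete_sentences("He came home; She left early.")
#   → ["He came home;", "She left early."]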
def chunk_by_fixed_overlap(text: str, tokenizer, max_tokens: int = MAX_INPUT_TOKENS, overlap_sentences: int = OVERLAP_SENTENCES) -> Tuple[List[str], List[int]]:
    """
    Split text into chunks with a FIXED sentence overlap (not a percentage).

    Returns:
        chunks: List of text chunks
        overlap_counts: Number of sentences in each chunk that overlap with the previous chunk

    Example with overlap_sentences=2:
        Sentences: ["A.", "B.", "C.", "D.", "E."]
        Chunks:
        - Chunk 0: ["A.", "B.", "C."] → overlap_count=0
        - Chunk 1: ["B.", "C.", "D."] → overlap_count=2
        - Chunk 2: ["C.", "D.", "E."] → overlap_count=2
    """
    # Tokenize the full text to count tokens
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return ([text], [0])

    # Split into sentences first (better context for CoEdit)
    sentences = split_into_complete_sentences(text)
    if len(sentences) == 0:
        return ([text], [0])

    chunks = []
    overlap_counts = []
    sentence_idx = 0
    while sentence_idx < len(sentences):
        current_chunk_sentences = []
        current_tokens = 0
        chunk_start_idx = sentence_idx

        # Determine the overlap count for this chunk
        if len(chunks) == 0:
            # First chunk has no overlap
            overlap_count = 0
        else:
            # Subsequent chunks: go back by overlap_sentences
            overlap_count = min(overlap_sentences, sentence_idx)
            chunk_start_idx = sentence_idx - overlap_count

        # Build the chunk starting from chunk_start_idx
        for i in range(chunk_start_idx, len(sentences)):
            sentence = sentences[i]
            sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
            sentence_token_count = len(sentence_tokens)
            # Stop if adding this sentence would exceed the token limit
            if current_tokens + sentence_token_count > max_tokens and current_chunk_sentences:
                break
            current_chunk_sentences.append(sentence)
            current_tokens += sentence_token_count
            # If this is a new sentence (not overlap), advance the index
            if i >= sentence_idx:
                sentence_idx = i + 1

        # Save the chunk
        if current_chunk_sentences:
            chunks.append(' '.join(current_chunk_sentences))
            overlap_counts.append(overlap_count)
        else:
            # Edge case: a single sentence is too long, split it by words
            long_sentence = sentences[chunk_start_idx]
            words = long_sentence.split()
            word_chunk = []
            word_tokens = 0
            for word in words:
                word_token_count = len(tokenizer.encode(word, add_special_tokens=False))
                if word_tokens + word_token_count > max_tokens and word_chunk:
                    chunks.append(' '.join(word_chunk))
                    overlap_counts.append(0)  # Word-level chunks have no sentence overlap
                    word_chunk = []
                    word_tokens = 0
                word_chunk.append(word)
                word_tokens += word_token_count
            if word_chunk:
                chunks.append(' '.join(word_chunk))
                overlap_counts.append(0)
            sentence_idx += 1

    return (chunks, overlap_counts)
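# Illustration (hypothetical token counts, not executed): with overlap_sentences=2 and a
# token budget that fits roughly three short sentences per chunk, the sentences
# ["A.", "B.", "C.", "D.", "E."] would come back as
#   chunks         = ["A. B. C.", "B. C. D.", "C. D. E."]
#   overlap_counts = [0, 2, 2]
# i.e. every chunk after the first re-processes the last two sentences of its
# predecessor, which is exactly what merge_with_adaptive_overlap() below relies on.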
def merge_with_adaptive_overlap(
    corrected_chunks: List[str],
    overlap_counts: List[int],
    similarity_threshold: float = 0.9
) -> str:
    """
    Merge corrected chunks using an adaptive overlap strategy.

    Compares overlapping sentences and:
    - If similarity >= threshold: keep the previous version (consistency)
    - If similarity < threshold: use the current version (better context)

    Args:
        corrected_chunks: List of corrected text chunks
        overlap_counts: Number of sentences in each chunk that overlap with the previous chunk
        similarity_threshold: Minimum similarity to keep the previous version (default 0.9)

    Returns:
        Merged corrected text

    Example:
        Chunk 1: "I like apples. They are delicious."
        Chunk 2 (overlap=1): "They are delicious. My friend agrees."
        → If both have "They are delicious." (similarity=1.0), keep chunk 1's version
        → Add "My friend agrees."
    """
    if len(corrected_chunks) == 0:
        return ""
    if len(corrected_chunks) == 1:
        return corrected_chunks[0]

    # Start with the first chunk (no overlap)
    result_sentences = split_into_complete_sentences(corrected_chunks[0])

    for chunk_idx in range(1, len(corrected_chunks)):
        current_chunk = corrected_chunks[chunk_idx]
        overlap_count = overlap_counts[chunk_idx]

        # Split the current chunk into sentences
        current_sentences = split_into_complete_sentences(current_chunk)

        if overlap_count == 0:
            # No overlap, just append all sentences
            result_sentences.extend(current_sentences)
            continue

        # Compare overlapping sentences:
        # the last N sentences of the result vs. the first N sentences of the current chunk
        comparison_count = min(overlap_count, len(result_sentences))
        previous_overlap_sents = result_sentences[-comparison_count:] if comparison_count > 0 else []
        current_overlap_sents = current_sentences[:min(overlap_count, len(current_sentences))]

        # Compare each overlapping sentence
        sentences_to_replace = []
        for i in range(min(len(previous_overlap_sents), len(current_overlap_sents))):
            prev_sent = previous_overlap_sents[i]
            curr_sent = current_overlap_sents[i]
            # Calculate similarity
            similarity = SequenceMatcher(None, prev_sent.lower(), curr_sent.lower()).ratio()
            logger.debug(f"Comparing overlap {i+1}: similarity={similarity:.3f}")
            logger.debug(f"  Previous: {prev_sent[:80]}...")
            logger.debug(f"  Current:  {curr_sent[:80]}...")
            if similarity < similarity_threshold:
                # The current version is significantly different, use it (better context)
                sentences_to_replace.append((i, curr_sent))
                logger.debug("  → Using current version (better context)")
            else:
                # Keep the previous version (consistency)
                logger.debug("  → Keeping previous version (consistent)")

        # Replace sentences if needed
        for idx, new_sent in sentences_to_replace:
            result_idx = len(result_sentences) - comparison_count + idx
            if 0 <= result_idx < len(result_sentences):
                result_sentences[result_idx] = new_sent

        # Add non-overlapping sentences from the current chunk
        non_overlap_start = min(overlap_count, len(current_sentences))
        if non_overlap_start < len(current_sentences):
            result_sentences.extend(current_sentences[non_overlap_start:])

    return ' '.join(result_sentences)
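# Illustration (assumed inputs, shown as comments so nothing runs at import time):
# two corrected chunks sharing a one-sentence overlap. The duplicated sentence is
# near-identical (similarity >= 0.9), so the earlier copy is kept and only the new
# sentence is appended:
#
#   merge_with_adaptive_overlap(
#       ["I like apples. They are delicious.",
#        "They are delicious. My friend agrees."],
#       overlap_counts=[0, 1],
#   )
#   → "I like apples. They are delicious. My friend agrees."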
def correct_text(
    text: str,
    prompt: str = "Fix grammatical errors in this sentence:"
) -> Tuple[str, int]:
    """
    Correct text using CoEdit-XL with a rolling-window approach.

    Returns: (corrected_text, num_chunks)

    CoEdit input format: "{prompt} {text}"

    Example prompts:
    - "Fix grammatical errors in this sentence:"
    - "Make this text more formal:"
    - "Paraphrase this:"
    - "Improve the clarity of this text:"
    """
    if not text.strip():
        return text, 0

    # Safety check - should never happen thanks to the endpoint guard
    if tokenizer is None or model is None:
        raise RuntimeError("Model or tokenizer not initialized")

    # Ensure the prompt ends with a trailing space so prompt and text do not run together
    if not prompt.endswith((' ', ':')):
        prompt = prompt + ' '
    elif prompt.endswith(':') and not prompt.endswith(': '):
        prompt = prompt + ' '

    # Calculate the token budget for the text itself (reserve tokens for the prompt)
    prompt_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    max_text_tokens = MAX_INPUT_TOKENS - prompt_tokens - 2  # -2 for special tokens
    if max_text_tokens < 100:
        logger.warning(f"Prompt is too long ({prompt_tokens} tokens); this may affect chunking")
        max_text_tokens = 100  # Minimum text tokens

    # Split the text into manageable chunks with a FIXED sentence overlap
    original_chunks, overlap_counts = chunk_by_fixed_overlap(text, tokenizer, max_text_tokens, overlap_sentences=OVERLAP_SENTENCES)
    logger.info(f"Split text into {len(original_chunks)} chunks with fixed {OVERLAP_SENTENCES}-sentence overlap")
    logger.info(f"Overlap counts: {overlap_counts}")

    corrected_chunks = []

    # Process chunks in batches for better throughput
    batch_size = min(BATCH_SIZE, len(original_chunks))
    for batch_start in range(0, len(original_chunks), batch_size):
        batch_end = min(batch_start + batch_size, len(original_chunks))
        batch = original_chunks[batch_start:batch_end]
        logger.info(f"Processing batch {batch_start//batch_size + 1} (chunks {batch_start+1}-{batch_end}/{len(original_chunks)})")
        try:
            # Prepend the user-provided prompt to every chunk (the prompt must repeat for each chunk)
            prompted_batch = [f"{prompt}{chunk}" for chunk in batch]

            # Tokenize the batch
            inputs = tokenizer(
                prompted_batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKENS
            )

            # Generate corrections with a fixed output budget
            outputs = model.generate(
                **inputs,
                max_length=MAX_OUTPUT_TOKENS,
                early_stopping=True,
            )

            # Decode outputs (the seq2seq decoder does not echo the prompt;
            # skip_special_tokens strips padding/EOS tokens)
            corrected_batch = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            corrected_chunks.extend(corrected_batch)
        except Exception as e:
            logger.error(f"Error processing batch: {str(e)}")
            # Fall back to the original chunks if correction fails
            corrected_chunks.extend(batch)

    # Merge corrected chunks with the adaptive overlap strategy
    final_corrected = merge_with_adaptive_overlap(corrected_chunks, overlap_counts, similarity_threshold=0.9)
    return final_corrected, len(original_chunks)
@app.get("/")
async def root():
    """Service info and basic health check endpoint"""
    return {
        "status": "online",
        "service": "CoEdit-XL Grammar Correction API",
        "model": "CoEdit-XL (Quantized ONNX)",
        "model_path": MODEL_PATH,
        "max_input_tokens": MAX_INPUT_TOKENS,
        "max_output_tokens": MAX_OUTPUT_TOKENS,
        "overlap_sentences": OVERLAP_SENTENCES,
        "adaptive_merging": True,
        "similarity_threshold": 0.9,
        "max_text_length": MAX_TEXT_LENGTH,
        "example_prompts": [
            "Fix grammatical errors in this sentence:",
            "Make this text more formal:",
            "Paraphrase this:",
            "Improve the clarity of this text:",
            "Neutralize this text:",
            "Make this easier to understand:"
        ],
        "endpoints": {
            "correction": "POST /correct",
            "health": "GET /health",
        }
    }
@app.get("/health")
async def health():
    """Enhanced health check with uptime and system metrics"""
    try:
        import psutil

        uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0
        cpu_percent = psutil.cpu_percent(interval=0.1)
        mem = psutil.virtual_memory()
        return {
            "status": "healthy",
            "version": VERSION,
            "uptime_seconds": round(uptime_seconds, 2),
            "uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A",
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
            "device": "CPU (ONNX Runtime)",
            "configuration": {
                "max_input_tokens": MAX_INPUT_TOKENS,
                "max_output_tokens": MAX_OUTPUT_TOKENS,
                "overlap_sentences": OVERLAP_SENTENCES,
                "max_text_length": MAX_TEXT_LENGTH,
                "batch_size": BATCH_SIZE,
                "rate_limit": RATE_LIMIT
            },
            "system": {
                "cpu_usage_percent": cpu_percent,
                "memory_used_gb": round(mem.used / (1024**3), 2),
                "memory_available_gb": round(mem.available / (1024**3), 2),
                "memory_percent": mem.percent
            }
        }
    except ImportError:
        # psutil not installed: return the same payload without system metrics
        uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0
        return {
            "status": "healthy",
            "version": VERSION,
            "uptime_seconds": round(uptime_seconds, 2),
            "uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A",
            "model_loaded": model is not None,
            "tokenizer_loaded": tokenizer is not None,
            "device": "CPU (ONNX Runtime)",
            "configuration": {
                "max_input_tokens": MAX_INPUT_TOKENS,
                "max_output_tokens": MAX_OUTPUT_TOKENS,
                "overlap_sentences": OVERLAP_SENTENCES,
                "max_text_length": MAX_TEXT_LENGTH,
                "batch_size": BATCH_SIZE,
                "rate_limit": RATE_LIMIT
            },
            "note": "Install psutil for system metrics"
        }
@app.post("/correct", response_model=CorrectionResponse)
@limiter.limit(RATE_LIMIT)
async def correct_grammar(request: Request, payload: CorrectionRequest):
    """
    Correct/edit text using CoEdit-XL with custom prompts.

    Handles texts longer than 512 tokens using a rolling-window approach.
    Rate limited to 100 requests per minute per IP.

    Example prompts:
    - "Fix grammatical errors in this sentence:"
    - "Make this text more formal:"
    - "Paraphrase this:"
    - "Improve the clarity of this text:"

    Note: slowapi requires the Starlette request parameter to be named `request`,
    so the request body is exposed as `payload`.
    """
    request_id = str(uuid.uuid4())[:8]

    if model is None or tokenizer is None:
        logger.error(f"[{request_id}] Model not loaded")
        raise HTTPException(
            status_code=503,
            detail="Service temporarily unavailable. Model not loaded."
        )

    # Validate text length (Pydantic already checks, but double-check for safety)
    if len(payload.text) > MAX_TEXT_LENGTH:
        logger.warning(f"[{request_id}] Text too large: {len(payload.text)} chars")
        raise HTTPException(
            status_code=413,
            detail=f"Text too large. Maximum allowed: {MAX_TEXT_LENGTH} characters."
        )

    try:
        start_time_req = time.time()
        logger.info(f"[{request_id}] Correction request | Text: {len(payload.text)} chars | Prompt: '{payload.prompt[:50]}...'")

        corrected, num_chunks = correct_text(
            payload.text,
            payload.prompt
        )

        processing_time = time.time() - start_time_req
        changes_made = corrected.strip() != payload.text.strip()
        logger.info(f"[{request_id}] Completed in {processing_time:.2f}s | Chunks: {num_chunks} | Changes: {changes_made}")

        return CorrectionResponse(
            original=payload.text,
            corrected=corrected,
            num_chunks=num_chunks,
            processing_time_seconds=round(processing_time, 3),
            changes_made=changes_made
        )
    except ValueError as e:
        # Input validation errors
        logger.warning(f"[{request_id}] Validation error: {str(e)}")
        raise HTTPException(
            status_code=400,
            detail="Invalid input. Please check your text and prompt."
        )
    except MemoryError:
        logger.error(f"[{request_id}] Memory error during processing")
        raise HTTPException(
            status_code=413,
            detail="Text too complex to process. Try shorter text."
        )
    except Exception as e:
        # Catch-all for unexpected errors - keep the client-facing message generic
        logger.error(f"[{request_id}] Correction failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="An internal error occurred. Please try again later."
        )
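# Example request (assuming the server is running locally on port 8080 as configured
# in the __main__ block below; adjust host/port for your deployment):
#
#   curl -X POST http://localhost:8080/correct \
#        -H "Content-Type: application/json" \
#        -d '{"text": "When I grow up, I start to understand what he said is quite right.",
#             "prompt": "Fix grammatical errors in this sentence:"}'
#
# The JSON response mirrors CorrectionResponse: original, corrected, num_chunks,
# processing_time_seconds, changes_made, and model.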
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler for unexpected errors"""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error occurred"}
    )
if __name__ == "__main__":
    # Run the server on port 8080
    # For production, put a reverse proxy (nginx) on ports 80/443 -> 8080
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="info",
        access_log=True,
    )
def self_ping_smoke_test():
    """Standalone smoke test: run this from a separate process while the server is up.

    uvicorn.run() above blocks until shutdown, so this check cannot run in the same
    process that is serving the API.
    """
    import subprocess

    # Give the server a few seconds to start
    time.sleep(5)
    # Check whether the server responds on the root endpoint
    try:
        result = subprocess.run(
            ["curl", "-f", "http://localhost:8080"],
            capture_output=True,
            timeout=5
        )
        print(f"Self-ping status: {result.returncode}")
        print(f"Output: {result.stdout.decode()}")
    except Exception as e:
        print(f"Self-ping failed: {e}")