from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator, ConfigDict from contextlib import asynccontextmanager from optimum.onnxruntime import ORTModelForSeq2SeqLM from transformers import AutoTokenizer from difflib import SequenceMatcher from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded import onnxruntime as ort import uvicorn import os import time import logging from typing import List, Tuple import re import uuid from datetime import datetime os.environ["CUDA_VISIBLE_DEVICES"] = "" # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # Global model variables model = None tokenizer = None start_time = None # Track server uptime # Configuration VERSION = "1.0.0" MODEL_PATH = "IDK100boysaj/coedit-xl-onnx-8bit" MAX_INPUT_TOKENS = 512 # CoEdit-XL hard limit (including prompt) MAX_OUTPUT_TOKENS = 512 # Fixed at 512 as requested OVERLAP_SENTENCES = 2 # Fixed 2-sentence overlap for adaptive merging MAX_TEXT_LENGTH = 50000 # Character limit to prevent memory issues BATCH_SIZE = 2 # Conservative batch size for 2vCPU + 16GB RAM # Rate limiting configuration limiter = Limiter(key_func=get_remote_address) RATE_LIMIT = "100/minute" # 100 requests per minute per IP @asynccontextmanager async def lifespan(app: FastAPI): """Load model on startup, cleanup on shutdown""" global model, tokenizer, start_time start_time = datetime.now() logger.info(f"Starting CoEdit-XL API v{VERSION}...") logger.info(f"Loading CoEdit-XL quantized ONNX model from {MODEL_PATH}...") try: # Configure ONNX Runtime session options for 2vCPU optimization session_options = ort.SessionOptions() session_options.intra_op_num_threads = 2 # Match vCPU count session_options.inter_op_num_threads = 1 # Single thread for inter-op session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL # Sequential execution session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL # Disable thread spinning to save CPU cycles on wait session_options.add_session_config_entry("session.intra_op.allow_spinning", "0") session_options.add_session_config_entry("session.inter_op.allow_spinning", "0") logger.info("Configured ONNX session: 2 intra-op threads, no spinning, sequential execution") # Load quantized ONNX model with Optimum # Explicitly specify file names for non-standard naming model = ORTModelForSeq2SeqLM.from_pretrained( MODEL_PATH, provider="CPUExecutionProvider", # Force CPU execution session_options=session_options, # Add optimized session options encoder_file_name="encoder_model.onnx", decoder_file_name="decoder_model.onnx", decoder_with_past_file_name="decoder_with_past_model.onnx", ) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( MODEL_PATH, ) logger.info("✓ Quantized ONNX model loaded successfully on CPU") logger.info(f"Optimizations: Sentence-based chunking | Fixed {OVERLAP_SENTENCES}-sentence overlap | Batch size: {BATCH_SIZE}") logger.info(f"Adaptive merging with 0.9 similarity threshold") logger.info(f"Max input tokens: {MAX_INPUT_TOKENS} | Max output tokens: {MAX_OUTPUT_TOKENS}") except Exception as e: logger.error(f"Failed to load model: {str(e)}") raise RuntimeError(f"Model loading failed: {str(e)}") yield # Server runs here # Cleanup on shutdown logger.info("Shutting down...") model = None tokenizer = None app = FastAPI( title="CoEdit-XL Grammar Correction API", version=VERSION, lifespan=lifespan, description="High-performance grammar correction API using quantized CoEdit-XL with rolling window support", ) # Add rate limiting app.state.limiter = limiter @app.exception_handler(RateLimitExceeded) async def rate_limit_handler(request: Request, exc: RateLimitExceeded): """Custom rate limit exceeded handler""" return JSONResponse( status_code=429, content={ "detail": "Rate limit exceeded. Maximum 100 requests per minute.", "retry_after": "60 seconds" } ) # CORS middleware for cross-origin requests app.add_middleware( CORSMiddleware, allow_origins=["*"], # Configure for production allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) class CorrectionRequest(BaseModel): model_config = ConfigDict( json_schema_extra={ "example": { "text": "When I grow up, I start to understand what he said is quite right.", "prompt": "Fix grammatical errors in this sentence:" } } ) text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH, description="Text to correct") prompt: str = Field("Fix grammatical errors in this sentence:", description="Task prompt (e.g., 'Fix grammatical errors:', 'Make this text more formal:', 'Paraphrase this:')") @field_validator('text') @classmethod def text_not_empty(cls, v: str) -> str: if not v.strip(): raise ValueError('Text cannot be empty or only whitespace') return v @field_validator('prompt') @classmethod def prompt_not_empty(cls, v: str) -> str: if not v.strip(): raise ValueError('Prompt cannot be empty or only whitespace') return v.strip() class CorrectionResponse(BaseModel): original: str corrected: str num_chunks: int processing_time_seconds: float changes_made: bool model: str = "CoEdit-XL (Quantized ONNX)" def split_into_complete_sentences(text: str) -> List[str]: """ Split text into complete sentences with proper punctuation handling. Handles: - Sentence terminators: . ? ! (and standalone semicolons) - Abbreviations: Dr., Mr., Mrs., Ms., Ph.D., etc., i.e., e.g., vs., Inc., Ltd. - Semicolons in lists (doesn't split on these) - Preserves punctuation with each sentence Examples: "Dr. Smith likes apples; oranges; milk. He is happy!" → ["Dr. Smith likes apples; oranges; milk.", "He is happy!"] "I love it. They are great!" → ["I love it.", "They are great!"] """ # Common abbreviations that end with period (not sentence boundaries) abbreviations = [ 'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Sr', 'Jr', 'Ph.D', 'M.D', 'B.A', 'M.A', 'Ph.D.', 'M.D.', 'B.A.', 'M.A.', 'etc', 'vs', 'Inc', 'Ltd', 'Co', 'Corp', 'i.e', 'e.g', 'a.m', 'p.m', 'U.S', 'U.K', 'i.e.', 'e.g.', 'a.m.', 'p.m.', 'U.S.', 'U.K.' ] # Split on sentence boundaries: [.!?] followed by space and capital letter # But check it's not an abbreviation first sentences = [] current = [] words = text.split() for i, word in enumerate(words): current.append(word) # Check if this word ends with a sentence terminator if word and word[-1] in '.!?': # Check if it's an abbreviation is_abbrev = any(word.rstrip('.!?').endswith(abbr.rstrip('.')) for abbr in abbreviations) # If not abbreviation and next word starts with capital, it's a sentence boundary if not is_abbrev and i + 1 < len(words) and words[i + 1] and words[i + 1][0].isupper(): sentences.append(' '.join(current)) current = [] # Add remaining words if current: sentences.append(' '.join(current)) # Handle semicolons that act as sentence separators (not in lists) # A semicolon is a separator if followed by space and capital letter result = [] for sent in sentences: # Check if this sentence contains standalone semicolons (sentence separators) # Pattern: semicolon followed by space and capital letter, not in a list context # List context: has colons before semicolons (e.g., "buy: apples; oranges") if ';' in sent: # Check if it's a list (has colon before semicolons) colon_pos = sent.find(':') semicolon_positions = [i for i, c in enumerate(sent) if c == ';'] # If semicolons come after a colon, it's likely a list - don't split is_list = colon_pos >= 0 and any(sp > colon_pos for sp in semicolon_positions) if not is_list: # Split on semicolons followed by space and capital subsents = re.split(r';\s+(?=[A-Z])', sent) # Add semicolon back to each part except last for i, ss in enumerate(subsents[:-1]): result.append(ss + ';') if subsents[-1].strip(): result.append(subsents[-1]) else: result.append(sent) else: result.append(sent) # Clean up and filter return [s.strip() for s in result if s.strip()] def chunk_by_fixed_overlap(text: str, tokenizer, max_tokens: int = MAX_INPUT_TOKENS, overlap_sentences: int = OVERLAP_SENTENCES) -> Tuple[List[str], List[int]]: """ Split text into chunks with FIXED sentence overlap (not percentage). Returns: chunks: List of text chunks overlap_counts: List indicating how many sentences in each chunk overlap from previous chunk Example with overlap_sentences=2: Sentences: ["A.", "B.", "C.", "D.", "E."] Chunks: - Chunk 0: ["A.", "B.", "C."] → overlap_count=0 - Chunk 1: ["B.", "C.", "D."] → overlap_count=2 - Chunk 2: ["C.", "D.", "E."] → overlap_count=2 """ # Tokenize full text to count tokens tokens = tokenizer.encode(text, add_special_tokens=False) if len(tokens) <= max_tokens: return ([text], [0]) # Split into sentences first (better for CoEdit context) sentences = split_into_complete_sentences(text) if len(sentences) == 0: return ([text], [0]) chunks = [] overlap_counts = [] sentence_idx = 0 while sentence_idx < len(sentences): current_chunk_sentences = [] current_tokens = 0 chunk_start_idx = sentence_idx # Determine overlap count for this chunk if len(chunks) == 0: # First chunk has no overlap overlap_count = 0 else: # Subsequent chunks: go back by overlap_sentences overlap_count = min(overlap_sentences, sentence_idx) chunk_start_idx = sentence_idx - overlap_count # Build chunk starting from chunk_start_idx for i in range(chunk_start_idx, len(sentences)): sentence = sentences[i] sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False) sentence_token_count = len(sentence_tokens) # Check if adding this sentence would exceed token limit if current_tokens + sentence_token_count > max_tokens and current_chunk_sentences: # Chunk is full, stop here break current_chunk_sentences.append(sentence) current_tokens += sentence_token_count # If this is a new sentence (not overlap), advance the index if i >= sentence_idx: sentence_idx = i + 1 # Save the chunk if current_chunk_sentences: chunks.append(' '.join(current_chunk_sentences)) overlap_counts.append(overlap_count) else: # Edge case: single sentence too long, split by words long_sentence = sentences[chunk_start_idx] words = long_sentence.split() word_chunk = [] word_tokens = 0 for word in words: word_token_count = len(tokenizer.encode(word, add_special_tokens=False)) if word_tokens + word_token_count > max_tokens and word_chunk: chunks.append(' '.join(word_chunk)) overlap_counts.append(0) # Word-level chunks don't have sentence overlap word_chunk = [] word_tokens = 0 word_chunk.append(word) word_tokens += word_token_count if word_chunk: chunks.append(' '.join(word_chunk)) overlap_counts.append(0) sentence_idx += 1 return (chunks, overlap_counts) def merge_with_adaptive_overlap( corrected_chunks: List[str], overlap_counts: List[int], similarity_threshold: float = 0.9 ) -> str: """ Merge corrected chunks using adaptive overlap strategy. Compares overlapping sentences and: - If similarity >= threshold: Keep previous version (consistency) - If similarity < threshold: Use current version (better context) Args: corrected_chunks: List of corrected text chunks overlap_counts: Number of sentences in each chunk that overlap from previous similarity_threshold: Minimum similarity to keep previous version (default 0.9) Returns: Merged corrected text Example: Chunk 1: "I like apples. They are delicious." Chunk 2 (overlap=1): "They are delicious. My friend agrees." → If both have "They are delicious." (similarity=1.0), keep chunk 1's version → Add "My friend agrees." """ if len(corrected_chunks) == 0: return "" if len(corrected_chunks) == 1: return corrected_chunks[0] # Start with first chunk (no overlap) result_sentences = split_into_complete_sentences(corrected_chunks[0]) for chunk_idx in range(1, len(corrected_chunks)): current_chunk = corrected_chunks[chunk_idx] overlap_count = overlap_counts[chunk_idx] # Split current chunk into sentences current_sentences = split_into_complete_sentences(current_chunk) if overlap_count == 0: # No overlap, just append all sentences result_sentences.extend(current_sentences) continue # Compare overlapping sentences # Get last N sentences from result (where N = overlap_count) comparison_count = min(overlap_count, len(result_sentences)) previous_overlap_sents = result_sentences[-comparison_count:] if comparison_count > 0 else [] # Get first N sentences from current chunk current_overlap_sents = current_sentences[:min(overlap_count, len(current_sentences))] # Compare each overlapping sentence sentences_to_replace = [] for i in range(min(len(previous_overlap_sents), len(current_overlap_sents))): prev_sent = previous_overlap_sents[i] curr_sent = current_overlap_sents[i] # Calculate similarity similarity = SequenceMatcher(None, prev_sent.lower(), curr_sent.lower()).ratio() logger.debug(f"Comparing overlap {i+1}: similarity={similarity:.3f}") logger.debug(f" Previous: {prev_sent[:80]}...") logger.debug(f" Current: {curr_sent[:80]}...") if similarity < similarity_threshold: # Current version is significantly different, use it (better context) sentences_to_replace.append((i, curr_sent)) logger.debug(f" → Using current version (better context)") else: # Keep previous version (consistency) logger.debug(f" → Keeping previous version (consistent)") # Replace sentences if needed for idx, new_sent in sentences_to_replace: result_idx = len(result_sentences) - comparison_count + idx if 0 <= result_idx < len(result_sentences): result_sentences[result_idx] = new_sent # Add non-overlapping sentences from current chunk non_overlap_start = min(overlap_count, len(current_sentences)) if non_overlap_start < len(current_sentences): result_sentences.extend(current_sentences[non_overlap_start:]) return ' '.join(result_sentences) def correct_text( text: str, prompt: str = "Fix grammatical errors in this sentence:" ) -> Tuple[str, int]: """ Correct text using CoEdit-XL with rolling window approach. Returns: (corrected_text, num_chunks) CoEdit format: "{prompt} {text}" Example prompts: - "Fix grammatical errors in this sentence:" - "Make this text more formal:" - "Paraphrase this:" - "Improve the clarity of this text:" """ if not text.strip(): return text, 0 # Safety check - should never happen due to endpoint guard if tokenizer is None or model is None: raise RuntimeError("Model or tokenizer not initialized") # Ensure prompt ends with space or colon for proper formatting if not prompt.endswith((' ', ':')): prompt = prompt + ' ' elif prompt.endswith(':') and not prompt.endswith(': '): prompt = prompt + ' ' # Calculate max tokens for text content (reserve tokens for prompt) prompt_tokens = len(tokenizer.encode(prompt, add_special_tokens=False)) max_text_tokens = MAX_INPUT_TOKENS - prompt_tokens - 2 # -2 for special tokens if max_text_tokens < 100: logger.warning(f"Prompt is too long ({prompt_tokens} tokens), may affect chunking") max_text_tokens = 100 # Minimum text tokens # Split text into manageable chunks with FIXED sentence overlap original_chunks, overlap_counts = chunk_by_fixed_overlap(text, tokenizer, max_text_tokens, overlap_sentences=OVERLAP_SENTENCES) logger.info(f"Split text into {len(original_chunks)} chunks with fixed {OVERLAP_SENTENCES}-sentence overlap") logger.info(f"Overlap counts: {overlap_counts}") corrected_chunks = [] # Process chunks in batches for better performance batch_size = min(BATCH_SIZE, len(original_chunks)) for batch_start in range(0, len(original_chunks), batch_size): batch_end = min(batch_start + batch_size, len(original_chunks)) batch = original_chunks[batch_start:batch_end] logger.info(f"Processing batch {batch_start//batch_size + 1} (chunks {batch_start+1}-{batch_end}/{len(original_chunks)})") try: # Prepare inputs with user-provided prompt (CRITICAL: prompt repeats for each chunk!) prompted_batch = [f"{prompt}{chunk}" for chunk in batch] # Tokenize batch inputs = tokenizer( prompted_batch, return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKENS ) # Generate corrections with fixed max_length=512 outputs = model.generate( **inputs, max_length=MAX_OUTPUT_TOKENS, early_stopping=True, ) # Decode outputs (prompt is removed by skip_special_tokens=True) corrected_batch = tokenizer.batch_decode(outputs, skip_special_tokens=True) corrected_chunks.extend(corrected_batch) except Exception as e: logger.error(f"Error processing batch: {str(e)}") # Fallback to original chunks if correction fails corrected_chunks.extend(batch) # Merge corrected chunks with adaptive overlap strategy final_corrected = merge_with_adaptive_overlap(corrected_chunks, overlap_counts, similarity_threshold=0.9) return final_corrected, len(original_chunks) @app.get("/") async def root(): """Health check endpoint""" return { "status": "online", "service": "CoEdit-XL Grammar Correction API", "model": "CoEdit-XL (Quantized ONNX)", "model_path": MODEL_PATH, "max_input_tokens": MAX_INPUT_TOKENS, "max_output_tokens": MAX_OUTPUT_TOKENS, "overlap_sentences": OVERLAP_SENTENCES, "adaptive_merging": True, "similarity_threshold": 0.9, "max_text_length": MAX_TEXT_LENGTH, "example_prompts": [ "Fix grammatical errors in this sentence:", "Make this text more formal:", "Paraphrase this:", "Improve the clarity of this text:", "Neutralize this text:", "Make this easier to understand:" ], "endpoints": { "correction": "POST /correct", "health": "GET /health", } } @app.get("/health") async def health(): """Enhanced health check with uptime and system metrics""" try: import psutil uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0 cpu_percent = psutil.cpu_percent(interval=0.1) mem = psutil.virtual_memory() return { "status": "healthy", "version": VERSION, "uptime_seconds": round(uptime_seconds, 2), "uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A", "model_loaded": model is not None, "tokenizer_loaded": tokenizer is not None, "device": "CPU (ONNX Runtime)", "configuration": { "max_input_tokens": MAX_INPUT_TOKENS, "max_output_tokens": MAX_OUTPUT_TOKENS, "overlap_sentences": OVERLAP_SENTENCES, "max_text_length": MAX_TEXT_LENGTH, "batch_size": BATCH_SIZE, "rate_limit": RATE_LIMIT }, "system": { "cpu_usage_percent": cpu_percent, "memory_used_gb": round(mem.used / (1024**3), 2), "memory_available_gb": round(mem.available / (1024**3), 2), "memory_percent": mem.percent } } except ImportError: uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0 return { "status": "healthy", "version": VERSION, "uptime_seconds": round(uptime_seconds, 2), "uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A", "model_loaded": model is not None, "tokenizer_loaded": tokenizer is not None, "device": "CPU (ONNX Runtime)", "configuration": { "max_input_tokens": MAX_INPUT_TOKENS, "max_output_tokens": MAX_OUTPUT_TOKENS, "overlap_sentences": OVERLAP_SENTENCES, "max_text_length": MAX_TEXT_LENGTH, "batch_size": BATCH_SIZE, "rate_limit": RATE_LIMIT }, "note": "Install psutil for system metrics" } @app.post("/correct", response_model=CorrectionResponse) @limiter.limit(RATE_LIMIT) async def correct_grammar(request_obj: Request, request: CorrectionRequest): """ Correct/edit text using CoEdit-XL with custom prompts. Handles texts longer than 512 tokens using rolling window approach. Rate limited to 100 requests per minute per IP. Example prompts: - "Fix grammatical errors in this sentence:" - "Make this text more formal:" - "Paraphrase this:" - "Improve the clarity of this text:" """ request_id = str(uuid.uuid4())[:8] if model is None or tokenizer is None: logger.error(f"[{request_id}] Model not loaded") raise HTTPException( status_code=503, detail="Service temporarily unavailable. Model not loaded." ) # Validate text length (Pydantic already checks, but double-check for safety) if len(request.text) > MAX_TEXT_LENGTH: logger.warning(f"[{request_id}] Text too large: {len(request.text)} chars") raise HTTPException( status_code=413, detail=f"Text too large. Maximum allowed: {MAX_TEXT_LENGTH} characters." ) try: start_time_req = time.time() logger.info(f"[{request_id}] Correction request | Text: {len(request.text)} chars | Prompt: '{request.prompt[:50]}...'") corrected, num_chunks = correct_text( request.text, request.prompt ) processing_time = time.time() - start_time_req changes_made = corrected.strip() != request.text.strip() logger.info(f"[{request_id}] Completed in {processing_time:.2f}s | Chunks: {num_chunks} | Changes: {changes_made}") return CorrectionResponse( original=request.text, corrected=corrected, num_chunks=num_chunks, processing_time_seconds=round(processing_time, 3), changes_made=changes_made ) except ValueError as e: # Input validation errors logger.warning(f"[{request_id}] Validation error: {str(e)}") raise HTTPException( status_code=400, detail="Invalid input. Please check your text and prompt." ) except MemoryError: logger.error(f"[{request_id}] Memory error during processing") raise HTTPException( status_code=413, detail="Text too complex to process. Try shorter text." ) except Exception as e: # Catch-all for unexpected errors - sanitize the error message logger.error(f"[{request_id}] Correction failed: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail="An internal error occurred. Please try again later." ) @app.exception_handler(Exception) async def global_exception_handler(request, exc): """Global exception handler for unexpected errors""" logger.error(f"Unhandled exception: {str(exc)}", exc_info=True) return JSONResponse( status_code=500, content={"detail": "Internal server error occurred"} ) if __name__ == "__main__": # Run server on port 8080 # For production, use reverse proxy (nginx) on ports 80/443 -> 8080 uvicorn.run( app, host="0.0.0.0", port=8080, log_level="info", access_log=True, ) import subprocess import time # Wait for server to start time.sleep(5) # Test if server responds try: result = subprocess.run( ["curl", "-f", "http://localhost:8080"], capture_output=True, timeout=5 ) print(f"Self-ping status: {result.returncode}") print(f"Output: {result.stdout.decode()}") except Exception as e: print(f"Self-ping failed: {e}")