coeditxl / app.py
IDK100boysaj's picture
Update app.py
026aefc verified
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator, ConfigDict
from contextlib import asynccontextmanager
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer
from difflib import SequenceMatcher
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import onnxruntime as ort
import uvicorn
import os
import time
import logging
from typing import List, Tuple
import re
import uuid
from datetime import datetime
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global model variables
model = None
tokenizer = None
start_time = None # Track server uptime
# Configuration
VERSION = "1.0.0"
MODEL_PATH = "IDK100boysaj/coedit-xl-onnx-8bit"
MAX_INPUT_TOKENS = 512 # CoEdit-XL hard limit (including prompt)
MAX_OUTPUT_TOKENS = 512 # Fixed at 512 as requested
OVERLAP_SENTENCES = 2 # Fixed 2-sentence overlap for adaptive merging
MAX_TEXT_LENGTH = 50000 # Character limit to prevent memory issues
BATCH_SIZE = 2 # Conservative batch size for 2vCPU + 16GB RAM
# Rate limiting configuration
limiter = Limiter(key_func=get_remote_address)
RATE_LIMIT = "100/minute" # 100 requests per minute per IP
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup, cleanup on shutdown"""
global model, tokenizer, start_time
start_time = datetime.now()
logger.info(f"Starting CoEdit-XL API v{VERSION}...")
logger.info(f"Loading CoEdit-XL quantized ONNX model from {MODEL_PATH}...")
try:
# Configure ONNX Runtime session options for 2vCPU optimization
session_options = ort.SessionOptions()
session_options.intra_op_num_threads = 2 # Match vCPU count
session_options.inter_op_num_threads = 1 # Single thread for inter-op
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL # Sequential execution
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Disable thread spinning to save CPU cycles on wait
session_options.add_session_config_entry("session.intra_op.allow_spinning", "0")
session_options.add_session_config_entry("session.inter_op.allow_spinning", "0")
logger.info("Configured ONNX session: 2 intra-op threads, no spinning, sequential execution")
# Load quantized ONNX model with Optimum
# Explicitly specify file names for non-standard naming
model = ORTModelForSeq2SeqLM.from_pretrained(
MODEL_PATH,
provider="CPUExecutionProvider", # Force CPU execution
session_options=session_options, # Add optimized session options
encoder_file_name="encoder_model.onnx",
decoder_file_name="decoder_model.onnx",
decoder_with_past_file_name="decoder_with_past_model.onnx",
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
)
logger.info("✓ Quantized ONNX model loaded successfully on CPU")
logger.info(f"Optimizations: Sentence-based chunking | Fixed {OVERLAP_SENTENCES}-sentence overlap | Batch size: {BATCH_SIZE}")
logger.info(f"Adaptive merging with 0.9 similarity threshold")
logger.info(f"Max input tokens: {MAX_INPUT_TOKENS} | Max output tokens: {MAX_OUTPUT_TOKENS}")
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
raise RuntimeError(f"Model loading failed: {str(e)}")
yield # Server runs here
# Cleanup on shutdown
logger.info("Shutting down...")
model = None
tokenizer = None
app = FastAPI(
title="CoEdit-XL Grammar Correction API",
version=VERSION,
lifespan=lifespan,
description="High-performance grammar correction API using quantized CoEdit-XL with rolling window support",
)
# Add rate limiting
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
"""Custom rate limit exceeded handler"""
return JSONResponse(
status_code=429,
content={
"detail": "Rate limit exceeded. Maximum 100 requests per minute.",
"retry_after": "60 seconds"
}
)
# CORS middleware for cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class CorrectionRequest(BaseModel):
model_config = ConfigDict(
json_schema_extra={
"example": {
"text": "When I grow up, I start to understand what he said is quite right.",
"prompt": "Fix grammatical errors in this sentence:"
}
}
)
text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH, description="Text to correct")
prompt: str = Field("Fix grammatical errors in this sentence:", description="Task prompt (e.g., 'Fix grammatical errors:', 'Make this text more formal:', 'Paraphrase this:')")
@field_validator('text')
@classmethod
def text_not_empty(cls, v: str) -> str:
if not v.strip():
raise ValueError('Text cannot be empty or only whitespace')
return v
@field_validator('prompt')
@classmethod
def prompt_not_empty(cls, v: str) -> str:
if not v.strip():
raise ValueError('Prompt cannot be empty or only whitespace')
return v.strip()
class CorrectionResponse(BaseModel):
original: str
corrected: str
num_chunks: int
processing_time_seconds: float
changes_made: bool
model: str = "CoEdit-XL (Quantized ONNX)"
def split_into_complete_sentences(text: str) -> List[str]:
"""
Split text into complete sentences with proper punctuation handling.
Handles:
- Sentence terminators: . ? ! (and standalone semicolons)
- Abbreviations: Dr., Mr., Mrs., Ms., Ph.D., etc., i.e., e.g., vs., Inc., Ltd.
- Semicolons in lists (doesn't split on these)
- Preserves punctuation with each sentence
Examples:
"Dr. Smith likes apples; oranges; milk. He is happy!"
→ ["Dr. Smith likes apples; oranges; milk.", "He is happy!"]
"I love it. They are great!"
→ ["I love it.", "They are great!"]
"""
# Common abbreviations that end with period (not sentence boundaries)
abbreviations = [
'Dr', 'Mr', 'Mrs', 'Ms', 'Prof', 'Sr', 'Jr',
'Ph.D', 'M.D', 'B.A', 'M.A', 'Ph.D.', 'M.D.', 'B.A.', 'M.A.',
'etc', 'vs', 'Inc', 'Ltd', 'Co', 'Corp',
'i.e', 'e.g', 'a.m', 'p.m', 'U.S', 'U.K',
'i.e.', 'e.g.', 'a.m.', 'p.m.', 'U.S.', 'U.K.'
]
# Split on sentence boundaries: [.!?] followed by space and capital letter
# But check it's not an abbreviation first
sentences = []
current = []
words = text.split()
for i, word in enumerate(words):
current.append(word)
# Check if this word ends with a sentence terminator
if word and word[-1] in '.!?':
# Check if it's an abbreviation
is_abbrev = any(word.rstrip('.!?').endswith(abbr.rstrip('.')) for abbr in abbreviations)
# If not abbreviation and next word starts with capital, it's a sentence boundary
if not is_abbrev and i + 1 < len(words) and words[i + 1] and words[i + 1][0].isupper():
sentences.append(' '.join(current))
current = []
# Add remaining words
if current:
sentences.append(' '.join(current))
# Handle semicolons that act as sentence separators (not in lists)
# A semicolon is a separator if followed by space and capital letter
result = []
for sent in sentences:
# Check if this sentence contains standalone semicolons (sentence separators)
# Pattern: semicolon followed by space and capital letter, not in a list context
# List context: has colons before semicolons (e.g., "buy: apples; oranges")
if ';' in sent:
# Check if it's a list (has colon before semicolons)
colon_pos = sent.find(':')
semicolon_positions = [i for i, c in enumerate(sent) if c == ';']
# If semicolons come after a colon, it's likely a list - don't split
is_list = colon_pos >= 0 and any(sp > colon_pos for sp in semicolon_positions)
if not is_list:
# Split on semicolons followed by space and capital
subsents = re.split(r';\s+(?=[A-Z])', sent)
# Add semicolon back to each part except last
for i, ss in enumerate(subsents[:-1]):
result.append(ss + ';')
if subsents[-1].strip():
result.append(subsents[-1])
else:
result.append(sent)
else:
result.append(sent)
# Clean up and filter
return [s.strip() for s in result if s.strip()]
def chunk_by_fixed_overlap(text: str, tokenizer, max_tokens: int = MAX_INPUT_TOKENS, overlap_sentences: int = OVERLAP_SENTENCES) -> Tuple[List[str], List[int]]:
"""
Split text into chunks with FIXED sentence overlap (not percentage).
Returns:
chunks: List of text chunks
overlap_counts: List indicating how many sentences in each chunk overlap from previous chunk
Example with overlap_sentences=2:
Sentences: ["A.", "B.", "C.", "D.", "E."]
Chunks:
- Chunk 0: ["A.", "B.", "C."] → overlap_count=0
- Chunk 1: ["B.", "C.", "D."] → overlap_count=2
- Chunk 2: ["C.", "D.", "E."] → overlap_count=2
"""
# Tokenize full text to count tokens
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) <= max_tokens:
return ([text], [0])
# Split into sentences first (better for CoEdit context)
sentences = split_into_complete_sentences(text)
if len(sentences) == 0:
return ([text], [0])
chunks = []
overlap_counts = []
sentence_idx = 0
while sentence_idx < len(sentences):
current_chunk_sentences = []
current_tokens = 0
chunk_start_idx = sentence_idx
# Determine overlap count for this chunk
if len(chunks) == 0:
# First chunk has no overlap
overlap_count = 0
else:
# Subsequent chunks: go back by overlap_sentences
overlap_count = min(overlap_sentences, sentence_idx)
chunk_start_idx = sentence_idx - overlap_count
# Build chunk starting from chunk_start_idx
for i in range(chunk_start_idx, len(sentences)):
sentence = sentences[i]
sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
sentence_token_count = len(sentence_tokens)
# Check if adding this sentence would exceed token limit
if current_tokens + sentence_token_count > max_tokens and current_chunk_sentences:
# Chunk is full, stop here
break
current_chunk_sentences.append(sentence)
current_tokens += sentence_token_count
# If this is a new sentence (not overlap), advance the index
if i >= sentence_idx:
sentence_idx = i + 1
# Save the chunk
if current_chunk_sentences:
chunks.append(' '.join(current_chunk_sentences))
overlap_counts.append(overlap_count)
else:
# Edge case: single sentence too long, split by words
long_sentence = sentences[chunk_start_idx]
words = long_sentence.split()
word_chunk = []
word_tokens = 0
for word in words:
word_token_count = len(tokenizer.encode(word, add_special_tokens=False))
if word_tokens + word_token_count > max_tokens and word_chunk:
chunks.append(' '.join(word_chunk))
overlap_counts.append(0) # Word-level chunks don't have sentence overlap
word_chunk = []
word_tokens = 0
word_chunk.append(word)
word_tokens += word_token_count
if word_chunk:
chunks.append(' '.join(word_chunk))
overlap_counts.append(0)
sentence_idx += 1
return (chunks, overlap_counts)
def merge_with_adaptive_overlap(
corrected_chunks: List[str],
overlap_counts: List[int],
similarity_threshold: float = 0.9
) -> str:
"""
Merge corrected chunks using adaptive overlap strategy.
Compares overlapping sentences and:
- If similarity >= threshold: Keep previous version (consistency)
- If similarity < threshold: Use current version (better context)
Args:
corrected_chunks: List of corrected text chunks
overlap_counts: Number of sentences in each chunk that overlap from previous
similarity_threshold: Minimum similarity to keep previous version (default 0.9)
Returns:
Merged corrected text
Example:
Chunk 1: "I like apples. They are delicious."
Chunk 2 (overlap=1): "They are delicious. My friend agrees."
→ If both have "They are delicious." (similarity=1.0), keep chunk 1's version
→ Add "My friend agrees."
"""
if len(corrected_chunks) == 0:
return ""
if len(corrected_chunks) == 1:
return corrected_chunks[0]
# Start with first chunk (no overlap)
result_sentences = split_into_complete_sentences(corrected_chunks[0])
for chunk_idx in range(1, len(corrected_chunks)):
current_chunk = corrected_chunks[chunk_idx]
overlap_count = overlap_counts[chunk_idx]
# Split current chunk into sentences
current_sentences = split_into_complete_sentences(current_chunk)
if overlap_count == 0:
# No overlap, just append all sentences
result_sentences.extend(current_sentences)
continue
# Compare overlapping sentences
# Get last N sentences from result (where N = overlap_count)
comparison_count = min(overlap_count, len(result_sentences))
previous_overlap_sents = result_sentences[-comparison_count:] if comparison_count > 0 else []
# Get first N sentences from current chunk
current_overlap_sents = current_sentences[:min(overlap_count, len(current_sentences))]
# Compare each overlapping sentence
sentences_to_replace = []
for i in range(min(len(previous_overlap_sents), len(current_overlap_sents))):
prev_sent = previous_overlap_sents[i]
curr_sent = current_overlap_sents[i]
# Calculate similarity
similarity = SequenceMatcher(None, prev_sent.lower(), curr_sent.lower()).ratio()
logger.debug(f"Comparing overlap {i+1}: similarity={similarity:.3f}")
logger.debug(f" Previous: {prev_sent[:80]}...")
logger.debug(f" Current: {curr_sent[:80]}...")
if similarity < similarity_threshold:
# Current version is significantly different, use it (better context)
sentences_to_replace.append((i, curr_sent))
logger.debug(f" → Using current version (better context)")
else:
# Keep previous version (consistency)
logger.debug(f" → Keeping previous version (consistent)")
# Replace sentences if needed
for idx, new_sent in sentences_to_replace:
result_idx = len(result_sentences) - comparison_count + idx
if 0 <= result_idx < len(result_sentences):
result_sentences[result_idx] = new_sent
# Add non-overlapping sentences from current chunk
non_overlap_start = min(overlap_count, len(current_sentences))
if non_overlap_start < len(current_sentences):
result_sentences.extend(current_sentences[non_overlap_start:])
return ' '.join(result_sentences)
def correct_text(
text: str,
prompt: str = "Fix grammatical errors in this sentence:"
) -> Tuple[str, int]:
"""
Correct text using CoEdit-XL with rolling window approach.
Returns: (corrected_text, num_chunks)
CoEdit format: "{prompt} {text}"
Example prompts:
- "Fix grammatical errors in this sentence:"
- "Make this text more formal:"
- "Paraphrase this:"
- "Improve the clarity of this text:"
"""
if not text.strip():
return text, 0
# Safety check - should never happen due to endpoint guard
if tokenizer is None or model is None:
raise RuntimeError("Model or tokenizer not initialized")
# Ensure prompt ends with space or colon for proper formatting
if not prompt.endswith((' ', ':')):
prompt = prompt + ' '
elif prompt.endswith(':') and not prompt.endswith(': '):
prompt = prompt + ' '
# Calculate max tokens for text content (reserve tokens for prompt)
prompt_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
max_text_tokens = MAX_INPUT_TOKENS - prompt_tokens - 2 # -2 for special tokens
if max_text_tokens < 100:
logger.warning(f"Prompt is too long ({prompt_tokens} tokens), may affect chunking")
max_text_tokens = 100 # Minimum text tokens
# Split text into manageable chunks with FIXED sentence overlap
original_chunks, overlap_counts = chunk_by_fixed_overlap(text, tokenizer, max_text_tokens, overlap_sentences=OVERLAP_SENTENCES)
logger.info(f"Split text into {len(original_chunks)} chunks with fixed {OVERLAP_SENTENCES}-sentence overlap")
logger.info(f"Overlap counts: {overlap_counts}")
corrected_chunks = []
# Process chunks in batches for better performance
batch_size = min(BATCH_SIZE, len(original_chunks))
for batch_start in range(0, len(original_chunks), batch_size):
batch_end = min(batch_start + batch_size, len(original_chunks))
batch = original_chunks[batch_start:batch_end]
logger.info(f"Processing batch {batch_start//batch_size + 1} (chunks {batch_start+1}-{batch_end}/{len(original_chunks)})")
try:
# Prepare inputs with user-provided prompt (CRITICAL: prompt repeats for each chunk!)
prompted_batch = [f"{prompt}{chunk}" for chunk in batch]
# Tokenize batch
inputs = tokenizer(
prompted_batch,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_INPUT_TOKENS
)
# Generate corrections with fixed max_length=512
outputs = model.generate(
**inputs,
max_length=MAX_OUTPUT_TOKENS,
early_stopping=True,
)
# Decode outputs (prompt is removed by skip_special_tokens=True)
corrected_batch = tokenizer.batch_decode(outputs, skip_special_tokens=True)
corrected_chunks.extend(corrected_batch)
except Exception as e:
logger.error(f"Error processing batch: {str(e)}")
# Fallback to original chunks if correction fails
corrected_chunks.extend(batch)
# Merge corrected chunks with adaptive overlap strategy
final_corrected = merge_with_adaptive_overlap(corrected_chunks, overlap_counts, similarity_threshold=0.9)
return final_corrected, len(original_chunks)
@app.get("/")
async def root():
"""Health check endpoint"""
return {
"status": "online",
"service": "CoEdit-XL Grammar Correction API",
"model": "CoEdit-XL (Quantized ONNX)",
"model_path": MODEL_PATH,
"max_input_tokens": MAX_INPUT_TOKENS,
"max_output_tokens": MAX_OUTPUT_TOKENS,
"overlap_sentences": OVERLAP_SENTENCES,
"adaptive_merging": True,
"similarity_threshold": 0.9,
"max_text_length": MAX_TEXT_LENGTH,
"example_prompts": [
"Fix grammatical errors in this sentence:",
"Make this text more formal:",
"Paraphrase this:",
"Improve the clarity of this text:",
"Neutralize this text:",
"Make this easier to understand:"
],
"endpoints": {
"correction": "POST /correct",
"health": "GET /health",
}
}
@app.get("/health")
async def health():
"""Enhanced health check with uptime and system metrics"""
try:
import psutil
uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0
cpu_percent = psutil.cpu_percent(interval=0.1)
mem = psutil.virtual_memory()
return {
"status": "healthy",
"version": VERSION,
"uptime_seconds": round(uptime_seconds, 2),
"uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A",
"model_loaded": model is not None,
"tokenizer_loaded": tokenizer is not None,
"device": "CPU (ONNX Runtime)",
"configuration": {
"max_input_tokens": MAX_INPUT_TOKENS,
"max_output_tokens": MAX_OUTPUT_TOKENS,
"overlap_sentences": OVERLAP_SENTENCES,
"max_text_length": MAX_TEXT_LENGTH,
"batch_size": BATCH_SIZE,
"rate_limit": RATE_LIMIT
},
"system": {
"cpu_usage_percent": cpu_percent,
"memory_used_gb": round(mem.used / (1024**3), 2),
"memory_available_gb": round(mem.available / (1024**3), 2),
"memory_percent": mem.percent
}
}
except ImportError:
uptime_seconds = (datetime.now() - start_time).total_seconds() if start_time else 0
return {
"status": "healthy",
"version": VERSION,
"uptime_seconds": round(uptime_seconds, 2),
"uptime_human": str(datetime.now() - start_time).split('.')[0] if start_time else "N/A",
"model_loaded": model is not None,
"tokenizer_loaded": tokenizer is not None,
"device": "CPU (ONNX Runtime)",
"configuration": {
"max_input_tokens": MAX_INPUT_TOKENS,
"max_output_tokens": MAX_OUTPUT_TOKENS,
"overlap_sentences": OVERLAP_SENTENCES,
"max_text_length": MAX_TEXT_LENGTH,
"batch_size": BATCH_SIZE,
"rate_limit": RATE_LIMIT
},
"note": "Install psutil for system metrics"
}
@app.post("/correct", response_model=CorrectionResponse)
@limiter.limit(RATE_LIMIT)
async def correct_grammar(request_obj: Request, request: CorrectionRequest):
"""
Correct/edit text using CoEdit-XL with custom prompts.
Handles texts longer than 512 tokens using rolling window approach.
Rate limited to 100 requests per minute per IP.
Example prompts:
- "Fix grammatical errors in this sentence:"
- "Make this text more formal:"
- "Paraphrase this:"
- "Improve the clarity of this text:"
"""
request_id = str(uuid.uuid4())[:8]
if model is None or tokenizer is None:
logger.error(f"[{request_id}] Model not loaded")
raise HTTPException(
status_code=503,
detail="Service temporarily unavailable. Model not loaded."
)
# Validate text length (Pydantic already checks, but double-check for safety)
if len(request.text) > MAX_TEXT_LENGTH:
logger.warning(f"[{request_id}] Text too large: {len(request.text)} chars")
raise HTTPException(
status_code=413,
detail=f"Text too large. Maximum allowed: {MAX_TEXT_LENGTH} characters."
)
try:
start_time_req = time.time()
logger.info(f"[{request_id}] Correction request | Text: {len(request.text)} chars | Prompt: '{request.prompt[:50]}...'")
corrected, num_chunks = correct_text(
request.text,
request.prompt
)
processing_time = time.time() - start_time_req
changes_made = corrected.strip() != request.text.strip()
logger.info(f"[{request_id}] Completed in {processing_time:.2f}s | Chunks: {num_chunks} | Changes: {changes_made}")
return CorrectionResponse(
original=request.text,
corrected=corrected,
num_chunks=num_chunks,
processing_time_seconds=round(processing_time, 3),
changes_made=changes_made
)
except ValueError as e:
# Input validation errors
logger.warning(f"[{request_id}] Validation error: {str(e)}")
raise HTTPException(
status_code=400,
detail="Invalid input. Please check your text and prompt."
)
except MemoryError:
logger.error(f"[{request_id}] Memory error during processing")
raise HTTPException(
status_code=413,
detail="Text too complex to process. Try shorter text."
)
except Exception as e:
# Catch-all for unexpected errors - sanitize the error message
logger.error(f"[{request_id}] Correction failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail="An internal error occurred. Please try again later."
)
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
"""Global exception handler for unexpected errors"""
logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
return JSONResponse(
status_code=500,
content={"detail": "Internal server error occurred"}
)
if __name__ == "__main__":
# Run server on port 8080
# For production, use reverse proxy (nginx) on ports 80/443 -> 8080
uvicorn.run(
app,
host="0.0.0.0",
port=8080,
log_level="info",
access_log=True,
)
import subprocess
import time
# Wait for server to start
time.sleep(5)
# Test if server responds
try:
result = subprocess.run(
["curl", "-f", "http://localhost:8080"],
capture_output=True,
timeout=5
)
print(f"Self-ping status: {result.returncode}")
print(f"Output: {result.stdout.decode()}")
except Exception as e:
print(f"Self-ping failed: {e}")