# Simple implementation of text summarization using the BART model
# (facebook/bart-large-cnn is a summarization checkpoint, not a translation model)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BartTokenizer, BartForConditionalGeneration

app = FastAPI()


# Define the request model
class SummarizationRequest(BaseModel):
    text: str
    max_length: int = 150
    min_length: int = 40


# Download and cache the model during initialization.
# This happens only once, when the app starts.
try:
    # Explicitly download to a specific directory with basic error handling
    cache_dir = "./model_cache"
    model_name = "facebook/bart-large-cnn"

    print(f"Loading tokenizer from {model_name}...")
    tokenizer = BartTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=False
    )

    print(f"Loading model from {model_name}...")
    model = BartForConditionalGeneration.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=False
    )

    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise


@app.post("/summarize/")
async def summarize_text(request: SummarizationRequest):
    # Tokenize the input text, truncating to the model's 1024-token limit
    inputs = tokenizer(
        request.text, return_tensors="pt", max_length=1024, truncation=True
    )

    # Generate the summary with beam search
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=request.max_length,
        min_length=request.min_length,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )

    # Decode the generated token IDs back into text
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return {"summary": summary}


# Basic health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": model_name}
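

# A minimal usage sketch. The file name "main.py" and port 8000 are assumptions;
# adjust both to match your setup.
#
# Start the server:
#   uvicorn main:app --reload
#
# Call the summarization endpoint:
#   curl -X POST http://localhost:8000/summarize/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long article text to summarize...", "max_length": 150, "min_length": 40}'
#
# Check that the model loaded:
#   curl http://localhost:8000/health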