from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BartTokenizer, BartForConditionalGeneration

app = FastAPI()


class SummarizationRequest(BaseModel):
    """Request body for the /summarize/ endpoint."""
    text: str
    max_length: int = 150
    min_length: int = 40


# Load the tokenizer and model once at import time so every request reuses them.
try:
    cache_dir = "./model_cache"
    model_name = "facebook/bart-large-cnn"

    print(f"Loading tokenizer from {model_name}...")
    tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)

    print(f"Loading model from {model_name}...")
    model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)

    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise


# Summarize the posted text and return the summary.
@app.post("/summarize/")
async def summarize_text(request: SummarizationRequest):
    # Tokenize the input, truncating to BART's 1024-token context window.
    inputs = tokenizer(request.text, return_tensors="pt", max_length=1024, truncation=True)

    # Generate the summary with beam search.
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=request.max_length,
        min_length=request.min_length,
        num_beams=4,
        length_penalty=2.0,
        early_stopping=True,
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return {"summary": summary}


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "facebook/bart-large-cnn"}