"""FastAPI service exposing single and batch paraphrase endpoints backed by
the tuner007/pegasus_paraphrase model (CPU-only, sized for HF Spaces)."""

import os
import time
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

app = FastAPI()

# NOTE(review): if the API_KEY env var is unset this dict keys on None and
# every request carrying a real header value is rejected -- confirm the env
# var is always provided at deploy time.
API_KEY = os.getenv("API_KEY")

# Configuration
API_KEYS = {
    API_KEY: "user1"  # In production, use a secure database
}

# Initialize model and tokenizer with smaller model for Spaces
MODEL_NAME = "tuner007/pegasus_paraphrase"

print("Loading model and tokenizer...")
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
device = "cpu"  # Force CPU for Spaces deployment
model = model.to(device)
print("Model and tokenizer loaded successfully!")


class TextRequest(BaseModel):
    # Single text to paraphrase plus generation options.
    text: str
    style: Optional[str] = "standard"
    num_variations: Optional[int] = 1


class BatchRequest(BaseModel):
    # Multiple texts processed sequentially with shared generation options.
    texts: List[str]
    style: Optional[str] = "standard"
    num_variations: Optional[int] = 1


async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")) -> str:
    """Dependency validating the ``X-API-Key`` request header.

    BUG FIX: ``fastapi.Header`` takes ``alias=``, not ``name=``; with
    ``name=`` the header was resolved from the parameter name as
    ``api-key`` and the intended ``X-API-Key`` header was never read.

    Raises:
        HTTPException: 403 when the key is not registered.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key


def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Return ``num_variations`` paraphrases of ``text`` in the given style.

    Unknown styles fall back to conservative sampling parameters.

    Raises:
        HTTPException: 500 wrapping any tokenizer/model failure.
    """
    try:
        # Per-style sampling parameters; higher temperature/top_k => more
        # diverse output.
        params = {
            "standard": {"temperature": 1.5, "top_k": 80},
            "formal": {"temperature": 1.0, "top_k": 50},
            "casual": {"temperature": 1.6, "top_k": 100},
            "creative": {"temperature": 2.8, "top_k": 170},
        }.get(style, {"temperature": 1.0, "top_k": 50})

        # BUG FIX: truncation was disabled, so inputs longer than the
        # model's maximum position embeddings crashed generation; truncate
        # to the tokenizer's model max length instead.
        inputs = tokenizer(
            text,
            truncation=True,
            padding="longest",
            return_tensors="pt",
        ).to(device)

        # BUG FIX: generate() requires num_return_sequences <= num_beams;
        # the fixed beam count of 10 raised for num_variations > 10, so
        # grow the beam count when needed.
        num_beams = max(10, num_variations)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # BUG FIX: max_length=10000 exceeded the model's positional
                # capacity; 256 tokens is ample for a paraphrase.
                max_length=256,
                num_return_sequences=num_variations,
                num_beams=num_beams,
                temperature=params["temperature"],
                top_k=params["top_k"],
                do_sample=True,
                early_stopping=True,
            )

        # Decode the generated outputs
        return [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")


@app.get("/")
async def root():
    """Health-check / landing route."""
    return {"message": "Paraphrase API is running. Use /docs for API documentation."}


@app.post("/api/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a single text and report timing metadata."""
    try:
        start_time = time.time()
        paraphrases = generate_paraphrase(
            request.text,
            request.style,
            request.num_variations,
        )
        processing_time = time.time() - start_time
        return {
            "status": "success",
            "original_text": request.text,
            "paraphrased_texts": paraphrases,
            "style": request.style,
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat(),
        }
    except HTTPException:
        # BUG FIX: HTTPExceptions from generate_paraphrase were previously
        # flattened into a generic 500; preserve their status codes.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase each text in the batch sequentially."""
    try:
        start_time = time.time()
        results = []
        for text in request.texts:
            paraphrases = generate_paraphrase(
                text,
                request.style,
                request.num_variations,
            )
            results.append({
                "original_text": text,
                "paraphrased_texts": paraphrases,
                "style": request.style,
            })
        processing_time = time.time() - start_time
        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{processing_time:.2f} seconds",
            "timestamp": datetime.now().isoformat(),
        }
    except HTTPException:
        # BUG FIX: preserve status codes raised by generate_paraphrase
        # instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))