# Spaces: Running  (HuggingFace Spaces status banner captured with the source scrape; not part of the code)
from fastapi import FastAPI, HTTPException, Header, Depends | |
from pydantic import BaseModel | |
from typing import Optional, List | |
from datetime import datetime | |
import torch | |
from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
import time | |
#from dotenv import load_dotenv | |
import os | |
#load_dotenv() | |
app = FastAPI() | |
API_KEY = os.getenv("API_KEY") | |
# Configuration | |
API_KEYS = { | |
API_KEY : "user1" # In production, use a secure database | |
} | |
# Initialize model and tokenizer with smaller model for Spaces | |
MODEL_NAME = "tuner007/pegasus_paraphrase" | |
print("Loading model and tokenizer...") | |
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache") | |
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache") | |
device = "cpu" # Force CPU for Spaces deployment | |
model = model.to(device) | |
print("Model and tokenizer loaded successfully!") | |
class TextRequest(BaseModel): | |
text: str | |
style: Optional[str] = "standard" | |
num_variations: Optional[int] = 1 | |
class BatchRequest(BaseModel): | |
texts: List[str] | |
style: Optional[str] = "standard" | |
num_variations: Optional[int] = 1 | |
async def verify_api_key(api_key: str = Header(..., name="X-API-Key")): | |
if api_key not in API_KEYS: | |
raise HTTPException(status_code=403, detail="Invalid API key") | |
return api_key | |
def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]: | |
try: | |
# Get parameters based on style | |
params = { | |
"standard": {"temperature": 1.5, "top_k": 80}, | |
"formal": {"temperature": 1.0, "top_k": 50}, | |
"casual": {"temperature": 1.6, "top_k": 100}, | |
"creative": {"temperature": 2.8, "top_k": 170}, | |
}.get(style, {"temperature": 1.0, "top_k": 50}) | |
# Tokenize the input text | |
inputs = tokenizer(text, truncation=False, padding='longest', return_tensors="pt").to(device) | |
# Generate paraphrases | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_length=10000, | |
num_return_sequences=num_variations, | |
num_beams=10, | |
temperature=params["temperature"], | |
top_k=params["top_k"], | |
do_sample=True, | |
early_stopping=True, | |
) | |
# Decode the generated outputs | |
paraphrases = [ | |
tokenizer.decode(output, skip_special_tokens=True) | |
for output in outputs | |
] | |
return paraphrases | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}") | |
async def root(): | |
return {"message": "Paraphrase API is running. Use /docs for API documentation."} | |
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)): | |
try: | |
start_time = time.time() | |
paraphrases = generate_paraphrase( | |
request.text, | |
request.style, | |
request.num_variations | |
) | |
processing_time = time.time() - start_time | |
return { | |
"status": "success", | |
"original_text": request.text, | |
"paraphrased_texts": paraphrases, | |
"style": request.style, | |
"processing_time": f"{processing_time:.2f} seconds", | |
"timestamp": datetime.now().isoformat() | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)): | |
try: | |
start_time = time.time() | |
results = [] | |
for text in request.texts: | |
paraphrases = generate_paraphrase( | |
text, | |
request.style, | |
request.num_variations | |
) | |
results.append({ | |
"original_text": text, | |
"paraphrased_texts": paraphrases, | |
"style": request.style | |
}) | |
processing_time = time.time() - start_time | |
return { | |
"status": "success", | |
"results": results, | |
"total_texts_processed": len(request.texts), | |
"processing_time": f"{processing_time:.2f} seconds", | |
"timestamp": datetime.now().isoformat() | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) |