# ptdevtest / app.py
# Author: lahiruchamika27 (commit 99838a4, verified)
from fastapi import FastAPI, HTTPException, Header, Depends
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import time
#from dotenv import load_dotenv
import os
#load_dotenv()
# FastAPI application instance; routes are registered on it via decorators below.
app = FastAPI()

# Single API key read from the environment (e.g. a Spaces secret).
API_KEY = os.getenv("API_KEY")
# Configuration
# NOTE(review): if API_KEY is unset, this becomes {None: "user1"} and no real
# header value can ever match — confirm the env var is always provided.
API_KEYS = {
    API_KEY : "user1"  # In production, use a secure database
}
# Initialize model and tokenizer with smaller model for Spaces
MODEL_NAME = "tuner007/pegasus_paraphrase"
print("Loading model and tokenizer...")
# Weights/tokenizer files are cached in ./model_cache so restarts don't re-download.
tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
device = "cpu"  # Force CPU for Spaces deployment
model = model.to(device)
print("Model and tokenizer loaded successfully!")
class TextRequest(BaseModel):
    """Request body for POST /api/paraphrase."""
    # The text to paraphrase.
    text: str
    # Decoding style; generate_paraphrase recognizes "standard", "formal",
    # "casual", "creative" and falls back to defaults for anything else.
    style: Optional[str] = "standard"
    # Number of paraphrase variations to return.
    num_variations: Optional[int] = 1
class BatchRequest(BaseModel):
    """Request body for POST /api/batch-paraphrase."""
    # The texts to paraphrase, processed in order.
    texts: List[str]
    # Decoding style applied to every text in the batch (see TextRequest.style).
    style: Optional[str] = "standard"
    # Number of paraphrase variations to return per text.
    num_variations: Optional[int] = 1
async def verify_api_key(api_key: str = Header(..., alias="X-API-Key")):
    """FastAPI dependency that validates the ``X-API-Key`` request header.

    Fix: ``Header`` takes ``alias=``, not ``name=``. With the original
    ``name="X-API-Key"`` the header actually matched was ``api-key``
    (auto-derived from the parameter name), so clients sending the
    documented ``X-API-Key`` header were rejected with a 422.

    Returns:
        The validated API key string (injected into the endpoint).

    Raises:
        HTTPException: 403 when the key is absent from API_KEYS.
    """
    if api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key
def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
    """Generate ``num_variations`` paraphrases of ``text`` with the Pegasus model.

    Args:
        text: Input text to paraphrase.
        style: One of "standard", "formal", "casual", "creative"; any other
            value falls back to conservative decoding defaults.
        num_variations: How many alternative paraphrases to return.

    Returns:
        A list of decoded paraphrase strings (length ``num_variations``).

    Raises:
        HTTPException: 500 wrapping any model/tokenizer failure.
    """
    try:
        # Per-style decoding parameters; unknown styles use the fallback.
        params = {
            "standard": {"temperature": 1.5, "top_k": 80},
            "formal": {"temperature": 1.0, "top_k": 50},
            "casual": {"temperature": 1.6, "top_k": 100},
            "creative": {"temperature": 2.8, "top_k": 170},
        }.get(style, {"temperature": 1.0, "top_k": 50})

        # Fix: truncate to the tokenizer's model max length. Pegasus has a
        # fixed positional-embedding size, so truncation=False crashed the
        # model on inputs longer than it supports.
        inputs = tokenizer(text, truncation=True, padding='longest', return_tensors="pt").to(device)

        # Fix: generate() requires num_return_sequences <= num_beams; the
        # hard-coded 10 beams made any num_variations > 10 raise.
        num_beams = max(10, num_variations)

        # Generate paraphrases (inference only — no gradients needed).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # NOTE(review): 10000 far exceeds any realistic paraphrase;
                # early_stopping + EOS ends generation much sooner in practice.
                max_length=10000,
                num_return_sequences=num_variations,
                num_beams=num_beams,
                temperature=params["temperature"],
                top_k=params["top_k"],
                do_sample=True,
                early_stopping=True,
            )

        # Decode each returned sequence, dropping special tokens.
        return [
            tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
@app.get("/")
async def root():
    """Landing/health endpoint that points clients at the interactive docs."""
    greeting = "Paraphrase API is running. Use /docs for API documentation."
    return {"message": greeting}
@app.post("/api/paraphrase")
async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase a single text.

    Requires a valid X-API-Key header (enforced by the verify_api_key
    dependency). Returns the original text, the generated variations,
    and timing metadata.
    """
    try:
        started = time.time()
        variations = generate_paraphrase(
            request.text,
            request.style,
            request.num_variations
        )
        elapsed = time.time() - started
        return {
            "status": "success",
            "original_text": request.text,
            "paraphrased_texts": variations,
            "style": request.style,
            "processing_time": f"{elapsed:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }
    except Exception as exc:
        # Surface any failure (including re-raised HTTPExceptions' messages) as a 500.
        raise HTTPException(status_code=500, detail=str(exc))
@app.post("/api/batch-paraphrase")
async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
    """Paraphrase every text in the batch, in order.

    Requires a valid X-API-Key header. Each entry in ``results`` pairs an
    original text with its generated variations; timing covers the whole batch.
    """
    try:
        started = time.time()
        # One result entry per input text, preserving input order.
        results = [
            {
                "original_text": text,
                "paraphrased_texts": generate_paraphrase(
                    text,
                    request.style,
                    request.num_variations
                ),
                "style": request.style
            }
            for text in request.texts
        ]
        elapsed = time.time() - started
        return {
            "status": "success",
            "results": results,
            "total_texts_processed": len(request.texts),
            "processing_time": f"{elapsed:.2f} seconds",
            "timestamp": datetime.now().isoformat()
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))