|
import io
import os
import secrets

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
|
|
|
|
|
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face image-feature-extraction pipeline for
# google/vit-base-patch16-384, built once at import time so every request
# reuses the same loaded model. The pipeline's device argument is derived from
# DEVICE so the two settings cannot drift apart (transformers convention:
# 0 = first CUDA device, -1 = CPU).
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    device=0 if DEVICE.type == "cuda" else -1,
)
|
|
|
|
|
# Shared secret that clients must present to /extract-features/.
# NOTE(review): when the SECRET_KEY environment variable is unset this falls
# back to a hard-coded, publicly visible value -- set it in every deployment.
SECRET_KEY = os.environ.get("SECRET_KEY", "default-secret")

# The ASGI application that the route decorators below register against.
app = FastAPI()
|
|
|
@app.get("/")
def read_root():
    """Return a static message confirming the app is serving requests."""
    status = {"message": "App is running successfully!"}
    return status
|
|
|
|
|
def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image
            (unrecognized format, truncated data, or a decompression bomb).
    """
    try:
        # `with` closes the underlying decoder; convert() returns a new,
        # fully loaded image that remains valid after the source is closed.
        with Image.open(io.BytesIO(contents)) as img:
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError is an OSError subclass; a bad upload is a
        # client error, not a server failure.
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from e


@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    """Return the first-token (CLS) embedding of an uploaded JPEG/PNG image.

    The caller must send the shared secret as a request header; FastAPI maps
    the ``secret_key`` parameter to the ``secret-key`` header by default.

    Returns:
        JSONResponse ``{"features": [...]}`` holding the vector at index
        ``[0, 0, :]`` of the feature-extraction pipeline output.

    Raises:
        HTTPException: 403 for a missing or wrong secret; 400 for an
            unsupported content type or an undecodable image; 500 for any
            other failure.
    """
    # Constant-time comparison, and never echo either key back to the client:
    # the previous message disclosed the server's SECRET_KEY to any caller.
    if secret_key is None or not secrets.compare_digest(
        secret_key.encode("utf-8"), SECRET_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")

    if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
        raise HTTPException(status_code=400, detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}")

    try:
        contents = await file.read()
        # Image decoding and model inference are blocking work; run them in
        # the threadpool so the event loop keeps serving other requests.
        # NOTE(review): requests may now run the shared pipeline concurrently;
        # confirm the pipeline and GPU memory tolerate that under load.
        image = await run_in_threadpool(_decode_image, contents)
        features = await run_in_threadpool(feature_extraction_pipeline, image)
        cls_embedding = np.asarray(features)[0, 0, :]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return JSONResponse(content={"features": cls_embedding.tolist()})
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this module is executed
    # directly; serves the app on all interfaces at port 7860.
    from uvicorn import run

    run(app, host="0.0.0.0", port=7860)
|
|