# app.py — last change: "Update app.py" (commit 1c61dbb, verified),
# authored by skanderovitch.
import io
import os
import secrets
import warnings

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import pipeline
# Torch device for inference: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the feature-extraction pipeline once at import time so every request
# reuses the same model instance.
feature_extraction_pipeline = pipeline(
    task="image-feature-extraction",
    model="google/vit-base-patch16-384",
    # Integer device form: 0 = first CUDA device, -1 = CPU. Derived from
    # DEVICE so the two settings can never disagree.
    device=0 if DEVICE.type == "cuda" else -1,
)
# Shared secret that clients must present to /extract-features/.
_DEFAULT_SECRET_KEY = "default-secret"
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    # The fallback value is written in this source file, so it offers no
    # real protection; make the missing configuration visible at startup.
    warnings.warn(
        "SECRET_KEY is not set; falling back to the built-in default key. "
        "Set the SECRET_KEY environment variable to protect the API.",
        RuntimeWarning,
        stacklevel=1,
    )
# Initialize FastAPI
app = FastAPI()
@app.get("/")
def read_root():
    """Liveness probe: confirm the service started and is answering."""
    payload = {"message": "App is running successfully!"}
    return payload
def _load_rgb_image(data):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when PIL cannot identify or fully decode the data.
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as e:
        # Image.open raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data, and decoding a truncated file raises OSError.
        # Both are bad client input, not server faults.
        raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e


# Endpoint to extract image features
@app.post("/extract-features/")
async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
    """Return the first-token (CLS) embedding of an uploaded JPEG/PNG image.

    Args:
        file: Uploaded image; its declared content type must be
            ``image/jpeg``, ``image/png`` or ``image/jpg``.
        secret_key: Value of the ``secret-key`` request header (FastAPI
            converts the underscore to a hyphen); must match ``SECRET_KEY``.

    Returns:
        JSONResponse ``{"features": [...]}`` containing the first token of
        the pipeline output as a list of floats.

    Raises:
        HTTPException: 403 for a missing or wrong key, 400 for an
            unsupported content type or undecodable image, 500 otherwise.
    """
    # Authenticate before doing any work. compare_digest avoids leaking the
    # key through response timing. The error detail must never echo the
    # expected key back to the caller.
    if secret_key is None or not secrets.compare_digest(
        secret_key.encode("utf-8"), SECRET_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")
    # Validate the client-declared file format
    if file.content_type not in ("image/jpeg", "image/png", "image/jpg"):
        raise HTTPException(status_code=400, detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}")
    try:
        contents = await file.read()
        image = _load_rgb_image(contents)
        # NOTE(review): this call is synchronous and blocks the event loop
        # for the whole inference; consider fastapi.concurrency's
        # run_in_threadpool once the pipeline's thread-safety is confirmed.
        features = feature_extraction_pipeline(image)  # Shape: (1, 577, 768)
        # Extract CLS token embedding (first token of the first item)
        cls_embedding = np.array(features)[0, 0, :]  # Shape: (768,)
        return JSONResponse(content={"features": cls_embedding.tolist()})
    except HTTPException:
        raise  # Reraise HTTPExceptions for proper status codes
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Running this file directly starts a standalone server; uvicorn is
    # imported here so it is only required for this entry point.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)