skanderovitch commited on
Commit
78a9e80
1 Parent(s): 201c97f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -1,8 +1,10 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
 
2
  from fastapi.responses import JSONResponse
3
  from transformers import pipeline
4
  from PIL import Image
5
  import torch
 
6
  import io
7
 
8
  # Set up the device for the model
@@ -15,6 +17,9 @@ feature_extraction_pipeline = pipeline(
15
  device=0 if torch.cuda.is_available() else -1
16
  )
17
 
 
 
 
18
  # Initialize FastAPI
19
  app = FastAPI()
20
 
@@ -24,8 +29,12 @@ def read_root():
24
 
25
  # Endpoint to extract image features
26
  @app.post("/extract-features/")
27
- async def extract_features(file: UploadFile = File(...)):
28
  try:
 
 
 
 
29
  # Validate file format
30
  if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
31
  raise HTTPException(status_code=400, detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}")
@@ -43,6 +52,8 @@ async def extract_features(file: UploadFile = File(...)):
43
  # Return the embedding vector
44
  return JSONResponse(content={"features": cls_embedding.tolist()})
45
 
 
 
46
  except Exception as e:
47
  raise HTTPException(status_code=500, detail=str(e))
48
 
 
1
+ import os
2
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Header
3
  from fastapi.responses import JSONResponse
4
  from transformers import pipeline
5
  from PIL import Image
6
  import torch
7
+ import numpy as np
8
  import io
9
 
10
  # Set up the device for the model
 
17
  device=0 if torch.cuda.is_available() else -1
18
  )
19
 
20
+ # Read SECRET_KEY from the environment
21
+ SECRET_KEY = os.getenv("SECRET_KEY", "default-secret")
22
+
23
  # Initialize FastAPI
24
  app = FastAPI()
25
 
 
29
 
30
  # Endpoint to extract image features
31
  @app.post("/extract-features/")
32
+ async def extract_features(file: UploadFile = File(...), secret_key: str = Header(None)):
33
  try:
34
+ # Verify the SECRET_KEY
35
+ if secret_key != SECRET_KEY:
36
+ raise HTTPException(status_code=403, detail="Invalid SECRET_KEY")
37
+
38
  # Validate file format
39
  if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
40
  raise HTTPException(status_code=400, detail=f"Unsupported file format. Upload a JPEG or PNG image. Received {file.content_type}")
 
52
  # Return the embedding vector
53
  return JSONResponse(content={"features": cls_embedding.tolist()})
54
 
55
+ except HTTPException:
56
+ raise # Reraise HTTPExceptions for proper status codes
57
  except Exception as e:
58
  raise HTTPException(status_code=500, detail=str(e))
59