alexfremont commited on
Commit
684956c
·
1 Parent(s): c763267

add new route

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. main.py +48 -0
  3. requirements.txt +2 -1
Dockerfile CHANGED
@@ -30,4 +30,4 @@ EXPOSE 7860
30
  # git clone $(cat /run/secrets/api_read)
31
 
32
  # Commande pour lancer l'application
33
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "4"]
 
30
  # git clone $(cat /run/secrets/api_read)
31
 
32
  # Commande pour lancer l'application
33
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py CHANGED
@@ -12,6 +12,9 @@ from huggingface_hub import hf_hub_download
12
  from architecture.resnet import ResNet
13
  import torch
14
  import logging
 
 
 
15
 
16
  app = FastAPI()
17
 
@@ -116,3 +119,48 @@ async def predict(request: PredictRequest):
116
  logging.info("confidence: %s", confidence)
117
  # Return the probabilities as JSON
118
  return JSONResponse(content={"confidence": confidence})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from architecture.resnet import ResNet
13
  import torch
14
  import logging
15
+ from typing import List
16
+ import httpx
17
+
18
 
19
  app = FastAPI()
20
 
 
119
  logging.info("confidence: %s", confidence)
120
  # Return the probabilities as JSON
121
  return JSONResponse(content={"confidence": confidence})
122
+
123
+
124
+ class BatchPredictRequest(BaseModel):
125
+ imageUrls: List[str]
126
+ modelName: str
127
+
128
+
129
+ @app.post("/batch_predict")
130
+ async def batch_predict(request: BatchPredictRequest):
131
+ model_name = request.modelName
132
+ results = []
133
+
134
+ # Verify if the model is loaded
135
+ if model_name not in model_pipelines:
136
+ raise HTTPException(status_code=404, detail="Model not found")
137
+
138
+ model = model_pipelines[model_name]
139
+
140
+ # Asynchronously process each image
141
+ async with httpx.AsyncClient() as client:
142
+ for image_url in request.imageUrls:
143
+ try:
144
+ response = await client.get(image_url)
145
+ image = Image.open(BytesIO(response.content))
146
+ except Exception as e:
147
+ results.append({"imageUrl": image_url, "error": "Invalid image URL"})
148
+ continue
149
+
150
+ # Preprocess the image
151
+ processed_image = process_image(image, size=image_size)
152
+
153
+ # Convert to tensor
154
+ image_tensor = transforms.ToTensor()(processed_image).unsqueeze(0)
155
+
156
+ # Perform inference
157
+ with torch.no_grad():
158
+ outputs = model(image_tensor)
159
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
160
+ predicted_probabilities = probabilities.numpy().tolist()
161
+ confidence = round(predicted_probabilities[0][1], 2)
162
+
163
+ results.append({"imageUrl": image_url, "confidence": confidence})
164
+
165
+ # Return the results as JSON
166
+ return JSONResponse(content={"results": results})
requirements.txt CHANGED
@@ -6,4 +6,5 @@ requests
6
  torchvision
7
  huggingface_hub
8
  torch
9
- numpy
 
 
6
  torchvision
7
  huggingface_hub
8
  torch
9
+ numpy
10
+ httpx