Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,59 +3,72 @@ import fastapi
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
6 |
-
class TorchTensor(torch.Tensor):
|
7 |
-
pass
|
8 |
-
|
9 |
-
class Prediction:
|
10 |
-
prediction: TorchTensor
|
11 |
-
|
12 |
app = fastapi.FastAPI(docs_url="/")
|
13 |
-
from transformers import ViTForImageClassification
|
14 |
-
|
15 |
-
# Define the number of classes in your custom dataset
|
16 |
-
num_classes = 20
|
17 |
|
18 |
-
#
|
19 |
model = ViTForImageClassification.from_pretrained(
|
20 |
'google/vit-base-patch16-224-in21k',
|
21 |
num_labels=num_classes # Specify the number of classes
|
22 |
)
|
23 |
-
|
24 |
-
|
25 |
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
|
|
|
26 |
# Define a function to preprocess the input image
|
27 |
def preprocess_input(input: fastapi.UploadFile):
|
28 |
image = Image.open(input.file)
|
29 |
image = image.resize((224, 224)).convert("RGB")
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
return
|
35 |
|
36 |
-
# Define an endpoint to make predictions
|
37 |
@app.post("/predict")
|
38 |
-
async def predict_endpoint(input:fastapi.UploadFile):
|
39 |
"""Make a prediction on an image uploaded by the user."""
|
40 |
|
41 |
# Preprocess the input image
|
42 |
-
|
43 |
|
44 |
# Make a prediction
|
45 |
-
prediction = model(
|
46 |
-
|
47 |
|
48 |
logits = prediction.logits
|
49 |
num_top_predictions = 3
|
50 |
top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
|
51 |
-
|
52 |
-
# Get the top 3 class indices and their probabilities
|
53 |
top_indices = top_predictions.indices.squeeze().tolist()
|
54 |
top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
|
55 |
-
|
56 |
-
#
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
app = fastapi.FastAPI(docs_url="/")
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
# Load your pre-trained model and other necessary components here
|
9 |
model = ViTForImageClassification.from_pretrained(
|
10 |
'google/vit-base-patch16-224-in21k',
|
11 |
num_labels=num_classes # Specify the number of classes
|
12 |
)
|
|
|
|
|
13 |
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
|
14 |
+
|
15 |
# Define a function to preprocess the input image
|
16 |
def preprocess_input(input: fastapi.UploadFile):
|
17 |
image = Image.open(input.file)
|
18 |
image = image.resize((224, 224)).convert("RGB")
|
19 |
+
input_data = np.array(image)
|
20 |
+
input_data = np.transpose(input_data, (2, 0, 1))
|
21 |
+
input_data = torch.from_numpy(input_data).float()
|
22 |
+
input_data = input_data.unsqueeze(0)
|
23 |
+
return input_data
|
24 |
|
|
|
25 |
@app.post("/predict")
|
26 |
+
async def predict_endpoint(input: fastapi.UploadFile):
|
27 |
"""Make a prediction on an image uploaded by the user."""
|
28 |
|
29 |
# Preprocess the input image
|
30 |
+
input_data = preprocess_input(input)
|
31 |
|
32 |
# Make a prediction
|
33 |
+
prediction = model(input_data)
|
|
|
34 |
|
35 |
logits = prediction.logits
|
36 |
num_top_predictions = 3
|
37 |
top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
|
|
|
|
|
38 |
top_indices = top_predictions.indices.squeeze().tolist()
|
39 |
top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
|
40 |
+
|
41 |
+
# Define class names for your dataset (modify as needed)
|
42 |
+
class_names = [
|
43 |
+
"Acral Lick Dermatitis",
|
44 |
+
"Acute moist dermatitis",
|
45 |
+
"Canine atopic dermatitis",
|
46 |
+
"Cherry Eye",
|
47 |
+
"Ear infections",
|
48 |
+
"External Parasites",
|
49 |
+
"Folliculitis",
|
50 |
+
"Healthy",
|
51 |
+
"Leishmaniasis",
|
52 |
+
"Lupus",
|
53 |
+
"Nuclear sclerosis",
|
54 |
+
"Otitis externa",
|
55 |
+
"Pruritus",
|
56 |
+
"Pyoderma",
|
57 |
+
"Rabies",
|
58 |
+
"Ringworm",
|
59 |
+
"Sarcoptic Mange",
|
60 |
+
"Sebaceous adenitis",
|
61 |
+
"Seborrhea",
|
62 |
+
"Skin tumor"
|
63 |
+
]
|
64 |
+
|
65 |
+
# Return the top N class names and their probabilities in JSON format
|
66 |
+
response_data = [
|
67 |
+
{
|
68 |
+
"class_name": class_names[idx],
|
69 |
+
"probability": prob
|
70 |
+
}
|
71 |
+
for idx, prob in zip(top_indices, top_probabilities)
|
72 |
+
]
|
73 |
+
|
74 |
+
return {"predictions": response_data}
|