SANJAYV10 commited on
Commit
719a218
1 Parent(s): 9cd2acf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -32
app.py CHANGED
@@ -3,59 +3,72 @@ import fastapi
3
  import numpy as np
4
  from PIL import Image
5
 
6
- class TorchTensor(torch.Tensor):
7
- pass
8
-
9
- class Prediction:
10
- prediction: TorchTensor
11
-
12
  app = fastapi.FastAPI(docs_url="/")
13
- from transformers import ViTForImageClassification
14
-
15
- # Define the number of classes in your custom dataset
16
- num_classes = 20
17
 
18
- # Initialize the ViTForImageClassification model
19
  model = ViTForImageClassification.from_pretrained(
20
  'google/vit-base-patch16-224-in21k',
21
  num_labels=num_classes # Specify the number of classes
22
  )
23
-
24
-
25
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
 
26
  # Define a function to preprocess the input image
27
  def preprocess_input(input: fastapi.UploadFile):
28
  image = Image.open(input.file)
29
  image = image.resize((224, 224)).convert("RGB")
30
- input = np.array(image)
31
- input = np.transpose(input, (2, 0, 1))
32
- input = torch.from_numpy(input).float()
33
- input = input.unsqueeze(0)
34
- return input
35
 
36
- # Define an endpoint to make predictions
37
  @app.post("/predict")
38
- async def predict_endpoint(input:fastapi.UploadFile):
39
  """Make a prediction on an image uploaded by the user."""
40
 
41
  # Preprocess the input image
42
- input = preprocess_input(input)
43
 
44
  # Make a prediction
45
- prediction = model(input)
46
-
47
 
48
  logits = prediction.logits
49
  num_top_predictions = 3
50
  top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
51
-
52
- # Get the top 3 class indices and their probabilities
53
  top_indices = top_predictions.indices.squeeze().tolist()
54
  top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
55
-
56
- # Get the disease names for the top 3 predictions
57
- disease_names = [disease_names[idx] for idx in top_indices]
58
-
59
- # Return the top 3 disease names and their probabilities in JSON format
60
- response_data = [{"disease_name": name, "probability": prob} for name, prob in zip(disease_names, top_probabilities)]
61
- return {"predictions": response_data}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  from PIL import Image
5
 
 
 
 
 
 
 
6
  app = fastapi.FastAPI(docs_url="/")
 
 
 
 
7
 
8
+ # Load your pre-trained model and other necessary components here
9
  model = ViTForImageClassification.from_pretrained(
10
  'google/vit-base-patch16-224-in21k',
11
  num_labels=num_classes # Specify the number of classes
12
  )
 
 
13
  model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
14
+
15
  # Define a function to preprocess the input image
16
  def preprocess_input(input: fastapi.UploadFile):
17
  image = Image.open(input.file)
18
  image = image.resize((224, 224)).convert("RGB")
19
+ input_data = np.array(image)
20
+ input_data = np.transpose(input_data, (2, 0, 1))
21
+ input_data = torch.from_numpy(input_data).float()
22
+ input_data = input_data.unsqueeze(0)
23
+ return input_data
24
 
 
25
  @app.post("/predict")
26
+ async def predict_endpoint(input: fastapi.UploadFile):
27
  """Make a prediction on an image uploaded by the user."""
28
 
29
  # Preprocess the input image
30
+ input_data = preprocess_input(input)
31
 
32
  # Make a prediction
33
+ prediction = model(input_data)
 
34
 
35
  logits = prediction.logits
36
  num_top_predictions = 3
37
  top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
 
 
38
  top_indices = top_predictions.indices.squeeze().tolist()
39
  top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
40
+
41
+ # Define class names for your dataset (modify as needed)
42
+ class_names = [
43
+ "Acral Lick Dermatitis",
44
+ "Acute moist dermatitis",
45
+ "Canine atopic dermatitis",
46
+ "Cherry Eye",
47
+ "Ear infections",
48
+ "External Parasites",
49
+ "Folliculitis",
50
+ "Healthy",
51
+ "Leishmaniasis",
52
+ "Lupus",
53
+ "Nuclear sclerosis",
54
+ "Otitis externa",
55
+ "Pruritus",
56
+ "Pyoderma",
57
+ "Rabies",
58
+ "Ringworm",
59
+ "Sarcoptic Mange",
60
+ "Sebaceous adenitis",
61
+ "Seborrhea",
62
+ "Skin tumor"
63
+ ]
64
+
65
+ # Return the top N class names and their probabilities in JSON format
66
+ response_data = [
67
+ {
68
+ "class_name": class_names[idx],
69
+ "probability": prob
70
+ }
71
+ for idx, prob in zip(top_indices, top_probabilities)
72
+ ]
73
+
74
+ return {"predictions": response_data}