SANJAYV10 commited on
Commit
c305981
1 Parent(s): 719a218

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -3,9 +3,19 @@ import fastapi
3
  import numpy as np
4
  from PIL import Image
5
 
 
 
 
 
 
 
6
  app = fastapi.FastAPI(docs_url="/")
 
 
 
 
7
 
8
- # Load your pre-trained model and other necessary components here
9
  model = ViTForImageClassification.from_pretrained(
10
  'google/vit-base-patch16-224-in21k',
11
  num_labels=num_classes # Specify the number of classes
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ class TorchTensor(torch.Tensor):
7
+ pass
8
+
9
+ class Prediction:
10
+ prediction: TorchTensor
11
+
12
  app = fastapi.FastAPI(docs_url="/")
13
+ from transformers import ViTForImageClassification
14
+
15
+ # Define the number of classes in your custom dataset
16
+ num_classes = 20
17
 
18
+ # Initialize the ViTForImageClassification model
19
  model = ViTForImageClassification.from_pretrained(
20
  'google/vit-base-patch16-224-in21k',
21
  num_labels=num_classes # Specify the number of classes