Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -3,9 +3,19 @@ import fastapi
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
app = fastapi.FastAPI(docs_url="/")
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
#
|
9 |
model = ViTForImageClassification.from_pretrained(
|
10 |
'google/vit-base-patch16-224-in21k',
|
11 |
num_labels=num_classes # Specify the number of classes
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
|
6 |
+
class TorchTensor(torch.Tensor):
|
7 |
+
pass
|
8 |
+
|
9 |
+
class Prediction:
|
10 |
+
prediction: TorchTensor
|
11 |
+
|
12 |
app = fastapi.FastAPI(docs_url="/")
|
13 |
+
from transformers import ViTForImageClassification
|
14 |
+
|
15 |
+
# Define the number of classes in your custom dataset
|
16 |
+
num_classes = 20
|
17 |
|
18 |
+
# Initialize the ViTForImageClassification model
|
19 |
model = ViTForImageClassification.from_pretrained(
|
20 |
'google/vit-base-patch16-224-in21k',
|
21 |
num_labels=num_classes # Specify the number of classes
|