Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ from torchvision import transforms
|
|
3 |
from transformers import AutoModelForImageClassification
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
-
from model import vit
|
7 |
|
8 |
def predict(inp):
|
9 |
inputs = data_transforms(inp)[None]
|
@@ -14,40 +13,14 @@ def predict(inp):
|
|
14 |
confidences = {labels[i]: probs[0][i] for i in range(num_classes)}
|
15 |
return confidences
|
16 |
|
17 |
-
"""height=28
|
18 |
-
width=28
|
19 |
-
batch_size=128
|
20 |
-
n_channels=3
|
21 |
-
patch_size=14
|
22 |
-
dim=384
|
23 |
-
n_head=12
|
24 |
-
feed_forward=1024
|
25 |
-
num_blocks=8"""
|
26 |
-
height=224
|
27 |
-
batch_size=128
|
28 |
-
width=224
|
29 |
-
n_channels=3
|
30 |
-
patch_size=16
|
31 |
-
dim=256
|
32 |
-
n_head=8
|
33 |
-
feed_forward=512
|
34 |
-
num_blocks=12
|
35 |
-
num_classes=2
|
36 |
data_transforms = transforms.Compose([
|
37 |
-
transforms.Resize((
|
38 |
transforms.ToTensor(), # Convert images to tensors
|
39 |
#transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the image data
|
40 |
])
|
41 |
|
42 |
-
|
43 |
-
model.
|
44 |
-
torch.load(f="vit_model.pt",
|
45 |
-
map_location=torch.device("cpu")) # load to CPU
|
46 |
-
)
|
47 |
-
print(model.state_dict())
|
48 |
-
"""labels = [
|
49 |
-
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
|
50 |
-
]"""
|
51 |
labels = [
|
52 |
'cat','dog'
|
53 |
]
|
|
|
3 |
from transformers import AutoModelForImageClassification
|
4 |
import gradio as gr
|
5 |
import torch
|
|
|
6 |
|
7 |
def predict(inp):
|
8 |
inputs = data_transforms(inp)[None]
|
|
|
13 |
confidences = {labels[i]: probs[0][i] for i in range(num_classes)}
|
14 |
return confidences
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
data_transforms = transforms.Compose([
|
17 |
+
transforms.Resize((224,224)), # Resize the images to a specific size
|
18 |
transforms.ToTensor(), # Convert images to tensors
|
19 |
#transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the image data
|
20 |
])
|
21 |
|
22 |
+
# Load model directly
|
23 |
+
model = AutoModelForImageClassification.from_pretrained("Manu8/vit_cats-vs-dogs", trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
labels = [
|
25 |
'cat','dog'
|
26 |
]
|