Manu8 committed
Commit 183c23b
Parent: b82f566

Upload app.py

Files changed (1):
  app.py (+56, -0)
app.py ADDED
@@ -0,0 +1,56 @@
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import AutoModelForImageClassification
+ import gradio as gr
+ import torch
+ from model import vit
+
+ def predict(inp):
+     # Preprocess the PIL image and add a batch dimension
+     inputs = data_transforms(inp)[None]
+     model.eval()
+     with torch.no_grad():
+         logits = model(inputs)
+     probs = torch.softmax(logits, dim=1)
+     confidences = {labels[i]: float(probs[0][i]) for i in range(num_classes)}
+     return confidences
+
+ # Earlier configuration, kept for reference:
+ # height=28, width=28, batch_size=128, n_channels=3, patch_size=14,
+ # dim=384, n_head=12, feed_forward=1024, num_blocks=8
+
+ # Current model/hyperparameter configuration
+ height = 224
+ width = 224
+ batch_size = 128
+ n_channels = 3
+ patch_size = 16
+ dim = 256
+ n_head = 8
+ feed_forward = 512
+ num_blocks = 12
+ num_classes = 2
+
+ data_transforms = transforms.Compose([
+     transforms.Resize((height, width)),  # Resize the images to a fixed size
+     transforms.ToTensor(),               # Convert images to tensors
+     # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the image data
+ ])
+
+ # Build the ViT model and load the saved weights onto the CPU
+ model = vit(height, width, n_channels, patch_size, batch_size, dim,
+             n_head, feed_forward, num_blocks, num_classes)
+ model.load_state_dict(
+     torch.load(f="vit_model.pt",
+                map_location=torch.device("cpu"))  # load to CPU
+ )
+ print(model.state_dict())  # debug: print the loaded weights
+
+ # Earlier 10-class label set, kept for reference:
+ # labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
+ labels = ['cat', 'dog']
+
+ gr.Interface(fn=predict,
+              inputs=gr.Image(type="pil"),
+              outputs=gr.Label(num_top_classes=2)).launch()  # only two classes
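
For a quick sanity check of the model outside the Gradio UI, a minimal sketch along these lines could be run at the bottom of app.py before the launch() call. It assumes a local sample image named sample.jpg, which is not part of this commit:

    from PIL import Image

    img = Image.open("sample.jpg").convert("RGB")  # hypothetical test image, forced to 3 channels
    print(predict(img))                            # {label: probability} dict over 'cat' and 'dog'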