Manu8 commited on
Commit
2d13268
1 Parent(s): 183c23b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -30
app.py CHANGED
@@ -3,7 +3,6 @@ from torchvision import transforms
3
  from transformers import AutoModelForImageClassification
4
  import gradio as gr
5
  import torch
6
- from model import vit
7
 
8
  def predict(inp):
9
  inputs = data_transforms(inp)[None]
@@ -14,40 +13,14 @@ def predict(inp):
14
  confidences = {labels[i]: probs[0][i] for i in range(num_classes)}
15
  return confidences
16
 
17
- """height=28
18
- width=28
19
- batch_size=128
20
- n_channels=3
21
- patch_size=14
22
- dim=384
23
- n_head=12
24
- feed_forward=1024
25
- num_blocks=8"""
26
- height=224
27
- batch_size=128
28
- width=224
29
- n_channels=3
30
- patch_size=16
31
- dim=256
32
- n_head=8
33
- feed_forward=512
34
- num_blocks=12
35
- num_classes=2
36
  data_transforms = transforms.Compose([
37
- transforms.Resize((height,width)), # Resize the images to a specific size
38
  transforms.ToTensor(), # Convert images to tensors
39
  #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the image data
40
  ])
41
 
42
- model = vit(height,width,n_channels,patch_size,batch_size,dim,n_head,feed_forward,num_blocks,num_classes)# Load saved weights
43
- model.load_state_dict(
44
- torch.load(f="vit_model.pt",
45
- map_location=torch.device("cpu")) # load to CPU
46
- )
47
- print(model.state_dict())
48
- """labels = [
49
- 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
50
- ]"""
51
  labels = [
52
  'cat','dog'
53
  ]
 
3
  from transformers import AutoModelForImageClassification
4
  import gradio as gr
5
  import torch
 
6
 
7
  def predict(inp):
8
  inputs = data_transforms(inp)[None]
 
13
  confidences = {labels[i]: probs[0][i] for i in range(num_classes)}
14
  return confidences
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  data_transforms = transforms.Compose([
17
+ transforms.Resize((224,224)), # Resize the images to a specific size
18
  transforms.ToTensor(), # Convert images to tensors
19
  #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize the image data
20
  ])
21
 
22
+ # Load model directly
23
+ model = AutoModelForImageClassification.from_pretrained("Manu8/vit_cats-vs-dogs", trust_remote_code=True)
 
 
 
 
 
 
 
24
  labels = [
25
  'cat','dog'
26
  ]