fohy24 commited on
Commit
2ded624
·
1 Parent(s): bfe52e9

update to use model v13_1_e9 for top 30 morphs prediction

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -1,43 +1,56 @@
 
1
  import torch
2
  from torch import nn
3
  from torchvision import models
4
  from torchvision.transforms import v2
5
- import os
6
  from huggingface_hub import hf_hub_download
7
  import gradio as gr
8
 
9
- labels = ['Pastel',
10
- 'Yellow Belly',
11
- 'Enchi',
12
- 'Clown',
13
- 'Leopard',
14
- 'Piebald',
15
- 'Orange Dream',
16
- 'Fire',
17
- 'Mojave',
18
- 'Pinstripe',
19
- 'Banana',
20
- 'Normal',
21
- 'Black Pastel',
22
- 'Lesser',
23
- 'Spotnose',
24
- 'Cinnamon',
25
- 'GHI',
26
- 'Hypo',
27
- 'Spider',
28
- 'Super Pastel']
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  num_labels = len(labels)
30
 
31
  # If using GPU
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
 
34
  hf_token = os.getenv('HF_token')
35
- model_path = hf_hub_download(repo_id="samfhy/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
36
  checkpoint = torch.load(model_path, map_location=device)
37
 
38
  new_layers = nn.Sequential(
39
- nn.LazyLinear(1280),
40
- nn.BatchNorm1d(1280),
41
  nn.ReLU(),
42
  nn.Dropout(0.5),
43
  nn.LazyLinear(num_labels)
@@ -47,7 +60,7 @@ IMAGE_SIZE = 480
47
  transform = v2.Compose([
48
  v2.ToImage(),
49
  v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
50
- v2.ToDtype(torch.float32, scale=True),
51
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
52
  ])
53
 
 
1
+ import os
2
  import torch
3
  from torch import nn
4
  from torchvision import models
5
  from torchvision.transforms import v2
 
6
  from huggingface_hub import hf_hub_download
7
  import gradio as gr
8
 
9
+ labels = [
10
+ 'Pastel',
11
+ 'Yellow Belly',
12
+ 'Enchi',
13
+ 'Clown',
14
+ 'Leopard',
15
+ 'Piebald',
16
+ 'Orange Dream',
17
+ 'Fire',
18
+ 'Mojave',
19
+ 'Pinstripe',
20
+ 'Banana',
21
+ 'Normal',
22
+ 'Black Pastel',
23
+ 'Lesser',
24
+ 'Spotnose',
25
+ 'Cinnamon',
26
+ 'GHI',
27
+ 'Hypo',
28
+ 'Spider',
29
+ 'Super Pastel',
30
+ 'Desert Ghost',
31
+ 'Black Head',
32
+ 'Vanilla',
33
+ 'Red Stripe',
34
+ 'Asphalt',
35
+ 'Gravel',
36
+ 'Butter',
37
+ 'Calico',
38
+ 'Albino',
39
+ 'Chocolate'
40
+ ]
41
+
42
  num_labels = len(labels)
43
 
44
  # If using GPU
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
 
47
  hf_token = os.getenv('HF_token')
48
+ model_path = hf_hub_download(repo_id="samfhy/morphmarket_model", filename="model_v13_1_epoch9.pt", token=hf_token)
49
  checkpoint = torch.load(model_path, map_location=device)
50
 
51
  new_layers = nn.Sequential(
52
+ nn.LazyLinear(2048),
53
+ nn.BatchNorm1d(2048),
54
  nn.ReLU(),
55
  nn.Dropout(0.5),
56
  nn.LazyLinear(num_labels)
 
60
  transform = v2.Compose([
61
  v2.ToImage(),
62
  v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
63
+ v2.ToDtype(torch.float32),
64
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
65
  ])
66