small fix
Browse files
app.py
CHANGED
|
@@ -19,6 +19,7 @@ class GradioApp:
|
|
| 19 |
custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
|
| 20 |
|
| 21 |
pretrained = models.vit_b_16().to(device).eval()
|
|
|
|
| 22 |
pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
|
| 23 |
|
| 24 |
self.models: Dict[str, Union[str, nn.Module]] = {
|
|
|
|
| 19 |
custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
|
| 20 |
|
| 21 |
pretrained = models.vit_b_16().to(device).eval()
|
| 22 |
+
pretrained.heads = nn.Linear(768, 3)
|
| 23 |
pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
|
| 24 |
|
| 25 |
self.models: Dict[str, Union[str, nn.Module]] = {
|