---
datasets:
- garythung/trashnet
pipeline_tag: image-classification
---

To load the saved state dict into a ResNet-50 model, follow these steps:
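```python
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

# Define the model: a ResNet-50 backbone pretrained on ImageNet
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Freeze the backbone parameters
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a 6-class head (TrashNet classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)

# Load the fine-tuned weights (map_location lets a GPU-saved checkpoint load on CPU)
state_dict = torch.load('trashnet_resnet50.pth', map_location='cpu')
model.load_state_dict(state_dict)

# Switch to evaluation mode
model.eval()
```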