datasets:
- garythung/trashnet
pipeline_tag: image-classification
To load this model's saved state dict, follow these steps:
# Names used below are assumed to be imported already, e.g.:
#   import torch
#   import torch.nn as nn
#   from torchvision.models import resnet50, ResNet50_Weights

# Define the model: a ResNet-50 initialised with the IMAGENET1K_V2 weights.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Freeze every backbone parameter; the replacement head created below is a
# fresh nn.Linear, so it keeps requires_grad=True.
for param in model.parameters():
    param.requires_grad = False

# Swap the final fully connected layer for a 6-output head
# (presumably the six trashnet classes -- confirm against the training code).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)

# Load the weights. map_location='cpu' lets a checkpoint that was saved on a
# GPU be loaded on a machine without one; move the model to a device afterwards
# if needed.
state_dict = torch.load('trashnet_resnet50.pth', map_location='cpu')
model.load_state_dict(state_dict)

# Switch to evaluation mode
model.eval()