Spaces:
Running
Running
import copy | |
import numpy as np | |
import time | |
import torch | |
import torchvision.transforms.functional as F | |
import matplotlib.pyplot as plt | |
from modules.eval import main_evaluation | |
from torch.optim import SGD, AdamW | |
from torchvision.models.detection import keypointrcnn_resnet50_fpn | |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
from tqdm import tqdm | |
from modules.utils import write_results | |
def get_arrow_model(num_classes, num_keypoints=2):
    """
    Build a Keypoint R-CNN (ResNet-50 backbone with FPN) whose prediction heads are
    resized for a custom number of classes and keypoints.

    Parameters:
    - num_classes (int): Number of output classes given to the box predictor.
      torchvision's FastRCNNPredictor counts the background as one of these classes,
      so the value is expected to include it (TODO confirm against the class
      dictionary passed by callers).
    - num_keypoints (int): Number of keypoints to predict for each detected object.

    Returns:
    - model (torch.nn.Module): The modified Keypoint R-CNN model.

    Steps:
    1. Instantiate the Keypoint R-CNN architecture with ``weights=None``, i.e. without
       loading COCO-trained detection weights.
    2. Replace the box predictor so that it outputs ``num_classes`` class scores and the
       matching bounding-box regressions.
    3. Replace the keypoint predictor so that it outputs ``num_keypoints`` heatmaps.
    """
    model = keypointrcnn_resnet50_fpn(weights=None)

    # Size the new box head from the head it replaces.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Read the keypoint head's input width from the existing predictor instead of
    # hard-coding it, so the replacement always matches the preceding keypoint head.
    keypoint_in_channels = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(keypoint_in_channels, num_keypoints)

    return model
def get_faster_rcnn_model(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 backbone with FPN) with a box head sized for a
    custom number of classes.

    Parameters:
    - num_classes (int): Number of classes for the model to detect, including the background class.

    Returns:
    - model (torch.nn.Module): The modified Faster R-CNN model.
    """
    # weights=None: the detection weights are not loaded from a COCO checkpoint.
    detector = fasterrcnn_resnet50_fpn(weights=None)
    roi_heads = detector.roi_heads

    # The replacement head must accept the same feature width as the head it replaces.
    classifier_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(classifier_width, num_classes)

    return detector
def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
    """
    Build a detection model, optionally load saved weights, move it to the available
    device and create its optimizer.

    Parameters:
    - dict: Class dictionary; only its length is used, as the number of classes.
      (The name shadows the builtin but is kept so keyword callers keep working.)
    - opti (str): 'SGD' or 'Adam' ('Adam' creates an AdamW optimizer).
    - learning_rate (float): Learning rate given to the optimizer.
    - model_to_load (str | None): Base name of a checkpoint read from
      './models/<name>.pth'. Nothing is loaded when falsy.
    - model_type (str): 'object' for Faster R-CNN, 'arrow' for Keypoint R-CNN with 2 keypoints.

    Returns:
    - (model, optimizer, device)

    Raises:
    - ValueError: If ``model_type`` or ``opti`` is not one of the supported values.
    """
    # Validate first: an unsupported value would otherwise leave `model` or
    # `optimizer` unbound and fail later with an unrelated UnboundLocalError.
    if model_type not in ('object', 'arrow'):
        raise ValueError(f"Unknown model_type {model_type!r}: expected 'object' or 'arrow'")
    if opti not in ('SGD', 'Adam'):
        raise ValueError(f"Optimizer not found: {opti!r} (expected 'SGD' or 'Adam')")

    num_classes = len(dict)
    if model_type == 'object':
        model = get_faster_rcnn_model(num_classes)
    else:
        model = get_arrow_model(num_classes, 2)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if model_to_load:
        # NOTE: torch.load deserialises with pickle; only load checkpoints from trusted sources.
        model.load_state_dict(torch.load('./models/'+ model_to_load +'.pth', map_location=device))
        print(f"Model '{model_to_load}' loaded")

    model.to(device)

    if opti == 'SGD':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    else:
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.00056, eps=1e-08, betas=(0.9, 0.999))

    return model, optimizer, device
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
    """
    Compute the average loss of ``model`` over ``data_loader`` without updating weights.

    Parameters:
    - model: Callable as ``model(images, targets)`` returning a dict of loss tensors.
    - data_loader: Iterable of ``(images, targets)`` batches, where each target is a
      dict of tensors.
    - device: Device the images and targets are moved to.
    - loss_config (dict | None): Maps loss names to booleans; only losses whose flag is
      truthy contribute to the total. When None, every returned loss is summed.
    - print_losses (bool): Print the total and per-component averages.

    Returns:
    - float: Average total loss per batch (NaN when the loader yields no batch).

    Note: the model is deliberately switched to training mode and left there.
    torchvision detection models return the loss dict only in training mode
    (assumes ``model`` is such a model - TODO confirm for other model types).
    Gradients are disabled with ``torch.no_grad()``.
    """
    model.train()

    component_keys = ('loss_classifier', 'loss_box_reg', 'loss_objectness',
                      'loss_rpn_box_reg', 'loss_keypoint')
    component_values = {key: [] for key in component_keys}
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for images, targets_im in tqdm(data_loader, desc="Evaluating"):
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            loss_dict = model(images, targets)

            if loss_config is not None:
                batch_loss = sum(loss for key, loss in loss_dict.items() if loss_config.get(key, False))
            else:
                batch_loss = sum(loss_dict.values())

            # float() accepts both a scalar tensor and the plain 0 produced when
            # no returned loss is enabled (where .item() would raise).
            total_loss += float(batch_loss)
            num_batches += 1

            # A loss the model does not return is recorded as 0.
            for key in component_keys:
                value = loss_dict.get(key)
                component_values[key].append(0 if value is None else value.item())

    nan = float('nan')
    avg_loss = total_loss / num_batches if num_batches else nan
    averages = {key: (np.mean(values) if values else nan) for key, values in component_values.items()}

    if print_losses:
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Average Classifier Loss: {averages['loss_classifier']:.4f}")
        print(f"Average Box Regression Loss: {averages['loss_box_reg']:.4f}")
        print(f"Average Objectness Loss: {averages['loss_objectness']:.4f}")
        print(f"Average RPN Box Regression Loss: {averages['loss_rpn_box_reg']:.4f}")
        print(f"Average Keypoints Loss: {averages['loss_keypoint']:.4f}")

    return avg_loss
def training_model(num_epochs, model, data_loader, subset_test_loader,
                   optimizer, model_to_load=None, change_learning_rate=5, start_key=30,
                   batch_size=4, crop_prob=0.2, h_flip_prob=0.3, v_flip_prob=0.3,
                   max_rotate_deg=20, rotate_proba=0.2, blur_prob=0.2,
                   score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
                   information_training='training', start_epoch=0, loss_config=None, model_type = 'object',
                   eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
    """
    Train ``model``, evaluate it after every epoch and checkpoint the best weights.

    Parameters:
    - num_epochs (int): Number of epochs to run.
    - model: Model called as ``model(images, targets)`` and returning a dict of loss tensors.
    - data_loader: Training batches of ``(images, targets)``.
    - subset_test_loader: Batches used for the per-epoch evaluation.
    - optimizer: Optimizer to step; its learning rate is decayed in place.
    - change_learning_rate (int): The learning rate is multiplied by 0.7 every this many
      epochs, and also after the test loss has increased on more than one epoch.
    - start_key (int): Epoch index at which the keypoint loss is enabled and parameters
      whose name contains 'keypoint' are made trainable.
    - batch_size, crop_prob, h_flip_prob, rotate_proba, blur_prob, information_training:
      Only used for the start-up log line and/or the checkpoint name.
    - model_to_load, v_flip_prob, max_rotate_deg, score_threshold, iou_threshold:
      Accepted for interface compatibility; not used by this function.
    - early_stop_f1_score (float): Training stops once the F1 score exceeds this value
      (never on the first epoch).
    - start_epoch (int): Offset added to epoch numbers in logs and checkpoint names.
    - loss_config (dict | None): Maps loss names to booleans selecting the losses that are
      optimised; None means all. The dict is modified in place when the keypoint loss is
      switched on.
    - model_type (str): Forwarded to ``main_evaluation``.
    - eval_metric (str): 'f1_score', 'precision' or 'recall' select that metric; any other
      value (e.g. 'loss', 'all') selects the negated test loss. 'loss' skips
      ``main_evaluation``; 'all' runs both evaluations.
    - device: Device the batches are moved to.

    Returns:
    - (model, metrics_list): the model (holding the best weights when any were recorded)
      and the list of per-epoch histories [total loss, classifier, box_reg, objectness,
      rpn_box_reg, keypoints, precision, recall, f1_score, test_loss].

    Raises:
    - ValueError: If ``loss_config`` enables none of the losses returned by the model.
    """
    if loss_config is None:
        print('No loss config found, all losses will be used.')
    else:
        print('The following losses will be used: ', end='')
        for key, value in loss_config.items():
            if value:
                print(key, end=", ")
        print()

    component_keys = ('loss_classifier', 'loss_box_reg', 'loss_objectness',
                      'loss_rpn_box_reg', 'loss_keypoint')

    # Epoch-wise histories.
    epoch_avg_losses = []
    epoch_component_losses = {key: [] for key in component_keys}
    epoch_precision = []
    epoch_recall = []
    epoch_f1_score = []
    epoch_test_loss = []

    def current_metrics():
        # Same ordering as documented in the Returns section.
        return [epoch_avg_losses,
                *(epoch_component_losses[key] for key in component_keys),
                epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]

    start_tot = time.time()
    best_metrics = -1000
    best_epoch = 0
    best_model_state = None
    name_model = None
    same = 0
    learning_rate = optimizer.param_groups[0]['lr']
    bad_test_loss = 0
    previous_test_loss = 1000
    nan = float('nan')

    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
    print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, Flip prob: {h_flip_prob}, Rotate prob: {rotate_proba}, Blur prob: {blur_prob}")

    for epoch in range(num_epochs):
        epoch_number = epoch + 1 + start_epoch

        if (epoch > 0 and epoch % change_learning_rate == 0) or bad_test_loss > 1:
            learning_rate = 0.7 * learning_rate
            # Decay in place so the configured optimizer type, its other
            # hyper-parameters and its internal state are all preserved.
            for group in optimizer.param_groups:
                group['lr'] = learning_rate
            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
            bad_test_loss = 0

        if epoch > 0 and epoch == start_key:
            print("Now it's training Keypoints also")
            if loss_config is not None:
                loss_config['loss_keypoint'] = True
            tracked = {id(p) for group in optimizer.param_groups for p in group['params']}
            untracked = []
            for name, param in model.named_parameters():
                if 'keypoint' in name:
                    param.requires_grad = True
                    if id(param) not in tracked:
                        untracked.append(param)
            # Parameters the optimizer was not built with would otherwise never be updated.
            if untracked:
                optimizer.add_param_group({'params': untracked, 'lr': learning_rate})

        model.train()
        start = time.time()
        total_loss = 0
        num_batches = 0
        batch_component_losses = {key: [] for key in component_keys}

        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch_number}')
        for images, targets_im in progress_bar:
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            optimizer.zero_grad()
            loss_dict = model(images, targets)

            losses = _combine_training_losses(loss_dict, loss_config)
            if not torch.is_tensor(losses):
                raise ValueError("loss_config does not enable any loss returned by the model")

            # A loss the model does not return is recorded as 0.
            for key in component_keys:
                value = loss_dict.get(key)
                batch_component_losses[key].append(0 if value is None else value.item())

            losses.backward()
            optimizer.step()

            total_loss += losses.item()
            num_batches += 1
            progress_bar.set_description(f'Epoch {epoch_number}, Loss: {losses.item():.4f}')

        avg_loss = total_loss / num_batches if num_batches else nan
        epoch_avg_losses.append(avg_loss)
        for key in component_keys:
            values = batch_component_losses[key]
            epoch_component_losses[key].append(np.mean(values) if values else nan)

        # Evaluate the model on the test set.
        if eval_metric != 'loss':
            avg_test_loss = 0
            # NOTE(review): the thresholds are fixed here; the score_threshold and
            # iou_threshold arguments are not forwarded - confirm whether intended.
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
            print(f"Epoch {epoch_number}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
            if eval_metric == 'all':
                avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
                print(f"Epoch {epoch_number}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
        else:
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0,0,0,0,0,0
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch_number}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")

        print(f"Time: {time.time() - start:.2f} [s]")

        if epoch > 0 and epoch % start_key == 0:
            print(f"Keypoints Accuracy: {key_accuracy:.4f}", end=", ")

        if eval_metric == 'f1_score':
            metric_used = f1_score
        elif eval_metric == 'precision':
            metric_used = precision
        elif eval_metric == 'recall':
            metric_used = recall
        else:
            # Negated so that "higher is better" holds for every metric.
            metric_used = -avg_test_loss

        # Remember the weights of the epoch with the best selected metric so far.
        if metric_used > best_metrics:
            best_metrics = metric_used
            best_epoch = epoch_number
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch > 0 and f1_score > early_stop_f1_score:
            same += 1

        epoch_precision.append(precision)
        epoch_recall.append(recall)
        epoch_f1_score.append(f1_score)
        epoch_test_loss.append(avg_test_loss)

        name_model = f"model_{type(optimizer).__name__}_{epoch_number}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"

        if same >= 1:
            # Early stop: persist what we have and leave the loop.
            _save_training_progress(model, best_model_state, name_model, current_metrics(), start_epoch, restore_best=False)
            break

        if epoch_number % 5 == 0:
            # Periodic checkpoint; training continues from the best weights so far.
            _save_training_progress(model, best_model_state, name_model, current_metrics(), start_epoch, restore_best=True)

        if avg_test_loss > previous_test_loss:
            bad_test_loss += 1
        previous_test_loss = avg_test_loss

    print(f"\n Total time: {(time.time() - start_tot)/60} minutes, Best Epoch is {best_epoch} with an f1_score of {best_metrics:.4f}")

    # Always defined, so the return below works even when no epoch produced a best state.
    metrics_list = current_metrics()
    if best_model_state is not None:
        _save_training_progress(model, best_model_state, name_model, metrics_list, start_epoch, restore_best=True)
        print(f"Name of the best model: {name_model}")

    return model, metrics_list


def _combine_training_losses(loss_dict, loss_config):
    """Sum the losses selected by ``loss_config`` (all of them when it is None)."""
    if loss_config is None:
        return sum(loss for key, loss in loss_dict.items())

    combined = 0
    for key, loss in loss_dict.items():
        if loss_config.get(key, False):
            if key == 'loss_classifier':
                # The classifier loss is weighted x3. NOTE(review): this multiplies the
                # tensor stored in loss_dict in place, so the classifier loss logged by
                # the caller is the weighted value - confirm this is intended.
                loss *= 3
            combined += loss
    return combined


def _save_training_progress(model, best_model_state, name_model, metrics_list, start_epoch, restore_best):
    """Save the best weights (when any exist) and write the metric histories; optionally reload the best weights."""
    if best_model_state is not None:
        torch.save(best_model_state, './models/'+ name_model +'.pth')
        if restore_best:
            model.load_state_dict(best_model_state)
    write_results(name_model, metrics_list, start_epoch)