Spaces:
Sleeping
Sleeping
import lightgbm as lgb | |
from pytorch_utils import * | |
from lightning_utils import * | |
from pytorch_vision_utils import * | |
# Folder holding every trained checkpoint / booster file loaded below.
models_path = Path('selected_models')
# Run inference on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# Class labels, indexed exactly like the models' output units. The order is the alphabetical
# order of the dataset's class-folder names (as listed by globbing the dataset's test folder
# when the models were trained): 13 ranks x 4 suits plus the joker, 53 labels in total.
classes = sorted(
    [f'{rank} of {suit}'
     for rank in ('ace', 'two', 'three', 'four', 'five', 'six', 'seven',
                  'eight', 'nine', 'ten', 'jack', 'queen', 'king')
     for suit in ('clubs', 'diamonds', 'hearts', 'spades')]
    + ['joker']
)
### Classification-layer-retrained NN
# RexNet 1.5. Its stem and feature layers are frozen below; per the section title, the
# classification layer is the part that was retrained for the 53 card classes.
class_retrain_model_name, class_retrain_extra = 'RexNet15', '0_First_Adam001_10_epochs'
class_retrain_model = timm.create_model('rexnet_150.nav_in1k', pretrained = True, num_classes = 53).eval().to(device)
# Eval-time preprocessing matching the timm model's data config (resize, crop, normalisation).
class_retrain_transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(class_retrain_model), is_training = False)
# requires_grad only affects gradient tracking, not forward outputs; kept to mirror the training setup.
for param in class_retrain_model.features.parameters(): param.requires_grad = False
for param in class_retrain_model.stem.parameters(): param.requires_grad = False
# The checkpoint is a plain state_dict (it is fed to load_state_dict), so weights_only=True is
# sufficient and avoids unrestricted unpickling of the file.
class_retrain_model.load_state_dict(torch.load(models_path / f'{class_retrain_model_name}_{class_retrain_extra}.pth', map_location = device, weights_only = True))
### Fully-retrained NN
# RexNet 1.0. Other checkpoints that were tried (swap the names / timm model to use one):
#   experiment 'ClassRetrain_EarlyStop', model 'RexNet10', extra 'Adam001_max10_epochs'
#   older runs saved without an experiment prefix, i.e. f'{model_name}_{extra}.pth', such as
#   'RexNet10_0_First_Adam001_10_epochs.pth' and 'RexNet15_0_First_Adam001_10_epochs.pth'
#   (RexNet 1.5 uses timm 'rexnet_150.nav_in1k'); these need the load path below adjusted.
full_retrain_experiment_name, full_retrain_model_name, full_retrain_extra = 'FullRetrain_EarlyStop', 'RexNet10', 'Adam001_max10_epochs'
full_retrain_model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = 53).eval().to(device)
# Eval-time preprocessing matching the timm model's data config.
full_retrain_transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(full_retrain_model), is_training = False)
# requires_grad only affects gradient tracking, not forward outputs, so these lines do not
# change this model's predictions.
for param in full_retrain_model.features.parameters(): param.requires_grad = False
for param in full_retrain_model.stem.parameters(): param.requires_grad = False
# The checkpoint is a plain state_dict (it is fed to load_state_dict), so weights_only=True is
# sufficient and avoids unrestricted unpickling of the file.
full_retrain_model.load_state_dict(torch.load(models_path / f'{full_retrain_experiment_name}_{full_retrain_model_name}_{full_retrain_extra}.pth', map_location = device, weights_only = True))
# Use the pred_image_class function from pytorch_vision_utils.py
### NN Feature Extraction -> Gradient Boosting
## Import the feature extraction model
# Backbone with its pretrained timm weights; no checkpoint is loaded for it here.
# Alternative backbone (RexNet 1.0, 1280 pooled features):
# feats_model_name = 'RexNet10'
# feats_model = timm.create_model('rexnet_100.nav_in1k', pretrained = True, num_classes = 53).eval().to(device)
feats_model_name = 'RexNet15'
feats_model = timm.create_model('rexnet_150.nav_in1k', pretrained = True, num_classes = 53).eval().to(device)
feats_transforms = timm.data.create_transform(**timm.data.resolve_model_data_config(feats_model), is_training = False)
# No training and only feature extraction: keep stem -> features -> global pooling and drop the
# classifier (i.e. for RexNet 1.0 [batch, 1280, 7, 7] -> [batch, 1280], and 1920 for RexNet 1.5).
for param in feats_model.parameters(): param.requires_grad = False
# A freshly constructed nn.Sequential starts in training mode, so switch the container itself to
# eval as well; its children are already in eval mode from the .eval() call above.
feats_model = nn.Sequential(OrderedDict(stem = feats_model.stem, features = feats_model.features, pool = feats_model.head.global_pool)).eval()
## Import the Gradient Boosting model
# In this block the three settings below only select which saved booster file to load:
# the hyper-parameters the booster was trained with are encoded in its filename.
num_iterations = 100
boosting_type = 'gbdt'  # 'gbdt' or 'dart' (dart also has max_drop, skip_drop, xgboost_dart_mode, uniform_drop)
data_sample_strategy = 'bagging'  # 'bagging' or 'goss' (goss also has top_rate, other_rate)
gb_model = lgb.Booster(
    model_file = models_path / f'{feats_model_name}_features_in_{num_iterations}_{data_sample_strategy}_{boosting_type}_lgbm.txt',
)
# Processing functions
def image_to_features(image: Image, model: torch.nn.Module, transform: tv.transforms.Compose,
                      device: torch.device = 'cuda' if torch.cuda.is_available() else 'cpu') -> torch.Tensor:
    '''Compute the feature vector of a single image and return it on the CPU.

    The model is put in eval mode. Under torch.inference_mode(), the image is passed through
    `transform`, given a leading batch dimension of size 1, moved to `device` and fed to the
    model. Every size-1 dimension of the output (including the batch one) is then squeezed
    away before the result is moved to the CPU.
    '''
    model.eval()
    with torch.inference_mode():
        batch = transform(image).unsqueeze(0).to(device)
        embedding = model(batch)
        embedding = embedding.squeeze().to('cpu')
    return embedding
def gb_predict_classes(feats: torch.Tensor, model: 'lgb.Booster', class_names: list[str]) -> dict[str, float]:
    '''Return the predicted probability of every class for one feature vector, most likely first.

    Args:
        feats: Feature vector of a single image (e.g. the output of image_to_features); anything
            numpy.asarray accepts works, such as a CPU tensor or a list of floats.
        model: Trained LightGBM booster; it is evaluated with its own best_iteration.
        class_names: Class labels in the same order as the model's output columns.

    Returns:
        OrderedDict mapping class name -> probability, sorted by decreasing probability
        (ties keep the class_names order).

    Raises:
        ValueError: if the number of class names differs from the number of model outputs.
    '''
    import numpy as np  # local import so the function does not rely on names leaked by star imports

    # LightGBM expects a 2-D (n_samples, n_features) input; this is a single sample.
    sample = np.asarray(feats).reshape(1, -1)
    # Use the best iteration of the booster that was passed in, not of any module-level booster.
    probs = model.predict(sample, num_iteration = model.best_iteration)[0]  # Already probabilities, not logits
    if len(probs) != len(class_names):
        raise ValueError(f'Model returned {len(probs)} class probabilities but {len(class_names)} class names were given')
    ranked = sorted(zip(class_names, (float(p) for p in probs)), key = itemgetter(1), reverse = True)
    return OrderedDict(ranked)