|
import albumentations as A
|
|
from transformers import PreTrainedModel
|
|
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
|
|
from .train import SphereClassifier
|
|
from .configuration_cetacean_classifier import CetaceanClassifierConfig
|
|
|
|
|
|
# Species labels, addressed by integer index. The classifier's forward pass
# indexes this array with the argsort positions of the species logits, so the
# order of the entries must not change.
WHALE_CLASSES = np.array(
    [
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    ]
)
|
|
|
|
|
|
class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper that runs a ``SphereClassifier`` on a single image.

    ``forward`` preprocesses one image array, feeds it to the wrapped model
    and returns the names of the three species with the highest species-head
    logits.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier and the preprocessing pipeline.

        Args:
            config: Model configuration (``config_class`` is
                ``CetaceanClassifierConfig``). Its dict form is passed to
                ``SphereClassifier``; its ``image_size`` and ``aug`` entries
                are read during preprocessing.
        """
        super().__init__(config)

        self.model = SphereClassifier(cfg=config.to_dict())

        # This wrapper is only used for prediction, so keep the inner model
        # in evaluation mode.
        self.model.eval()
        self.config = config
        # NOTE(review): data_aug=True means every prediction goes through a
        # random crop, so repeated calls on the same image can differ.
        # Confirm this is intended for inference before changing it.
        self.transforms = self.make_transforms(data_aug=True)

    def make_transforms(self, data_aug: bool):
        """Return the albumentations pipeline applied after resizing.

        Args:
            data_aug: When true, the pipeline contains a single
                ``RandomResizedCrop`` configured from ``config.image_size``
                and ``config.aug`` (keys ``crop_scale``, ``crop_l``,
                ``crop_r``). When false, the pipeline is empty.

        Returns:
            An ``A.Compose`` wrapping the selected transforms.
        """
        augments = []
        if data_aug:
            aug = self.config.aug
            augments = [
                A.RandomResizedCrop(
                    self.config.image_size[0],
                    self.config.image_size[1],
                    scale=(aug["crop_scale"], 1.0),
                    ratio=(aug["crop_l"], aug["crop_r"]),
                ),
            ]
        return A.Compose(augments)

    def preprocess_image(self, img) -> torch.Tensor:
        """Convert an image array into a batched, channel-first tensor.

        Args:
            img: Image array; it is treated as BGR and converted to RGB.

        Returns:
            A float tensor with a leading batch dimension of 1 and the
            channel axis moved in front of the two spatial axes.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image = cv2.resize(rgb, self.config.image_size, interpolation=cv2.INTER_CUBIC)
        image = self.transforms(image=image)["image"]
        # permute(2, 0, 1) maps H x W x C to C x H x W. The previous
        # transpose(2, 0) swapped the first and last axes, which yields
        # C x W x H, i.e. an image with height and width exchanged.
        return torch.Tensor(image).permute(2, 0, 1).unsqueeze(0)

    def forward(self, img, labels=None):
        """Predict the three most likely species for one image.

        Args:
            img: Image array handed to ``preprocess_image``.
            labels: Accepted but not used.

        Returns:
            A dict with key ``"predictions"`` holding an array of three
            entries of ``WHALE_CLASSES``, ordered from highest to lowest
            species logit.
        """
        # Keep the input on the same device as the model weights.
        tensor = self.preprocess_image(img).to(self.device)

        # The result is converted to numpy, so no autograd graph is needed.
        with torch.no_grad():
            _head_id_logits, head_species_logits = self.model(tensor)

        # .cpu() makes the numpy conversion work when the model runs on an
        # accelerator as well as on the CPU.
        species_logits = head_species_logits.detach().cpu().numpy()

        # argsort is ascending; reverse the first (only) row and keep the
        # three highest-scoring class indices.
        top_three_idx = species_logits.argsort()[0][::-1][:3]
        top_three_whale_preds = WHALE_CLASSES[top_three_idx]

        return {"predictions": top_three_whale_preds}
|
|
|