Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
cetacean-classifier / modeling_cetacean_classifier.py
MalloryWittwerEPFL's picture
Upload model
d514464 verified
import albumentations as A
from transformers import PreTrainedModel
# from PIL import Image
import numpy as np
import torch
import cv2
from .train import SphereClassifier
from .configuration_cetacean_classifier import CetaceanClassifierConfig
# Species names addressed by position: the classifier's ``forward`` indexes
# this array with argsort indices of the species-head logits, so the order of
# the entries must not be changed.
WHALE_CLASSES = np.asarray(
    (
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    )
)
class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs ``SphereClassifier`` on one image.

    ``forward`` takes a raw image array, preprocesses it (colour conversion,
    resize, configured transforms) and returns the names of the three species
    whose species-head logits are highest, looked up in ``WHALE_CLASSES``.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        """Build the wrapped classifier and the preprocessing pipeline.

        Args:
            config: A ``CetaceanClassifierConfig``. Its ``to_dict()`` output is
                handed to ``SphereClassifier`` as ``cfg``; ``image_size`` and
                ``aug`` are read when building the transforms.
        """
        super().__init__(config)
        self.model = SphereClassifier(cfg=config.to_dict())
        self.model.eval()
        self.config = config
        # NOTE(review): data_aug=True installs a RandomResizedCrop, which is
        # random, so repeated calls on the same image may give different
        # predictions. Kept as-is to preserve existing behaviour; confirm
        # this is intended for inference.
        self.transforms = self.make_transforms(data_aug=True)

    def make_transforms(self, data_aug: bool):
        """Return the albumentations pipeline applied after resizing.

        Args:
            data_aug: When true, the pipeline contains a single
                ``RandomResizedCrop`` parameterised by ``config.image_size``
                and ``config.aug`` (keys ``crop_scale``, ``crop_l``,
                ``crop_r``). When false, the pipeline is empty.

        Returns:
            An ``A.Compose`` wrapping the selected transforms.
        """
        if not data_aug:
            return A.Compose([])

        aug = self.config.aug
        random_crop = A.RandomResizedCrop(
            self.config.image_size[0],
            self.config.image_size[1],
            scale=(aug["crop_scale"], 1.0),
            ratio=(aug["crop_l"], aug["crop_r"]),
        )
        return A.Compose([random_crop])

    def preprocess_image(self, img) -> torch.Tensor:
        """Convert a raw image array into a single-item batch tensor.

        Args:
            img: Image array accepted by ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB`` (presumably a BGR ``uint8`` array as
                produced by ``cv2.imread`` -- confirm against callers).

        Returns:
            A float tensor with a leading batch dimension of 1 and the
            channel axis moved to position 1.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # A config round-tripped through JSON typically holds a list here;
        # some OpenCV builds only accept a tuple for ``dsize``.
        # NOTE(review): cv2.resize reads this as (width, height) while
        # make_transforms passes the same values as (height, width); this
        # only agrees for square sizes -- confirm image_size is square.
        target_size = tuple(self.config.image_size)
        image = cv2.resize(rgb, target_size, interpolation=cv2.INTER_CUBIC)
        image = self.transforms(image=image)["image"]
        # NOTE(review): transpose(2, 0) swaps the first and last axes, giving
        # (C, W, H) rather than the usual (C, H, W). Left unchanged so that
        # outputs match the published behaviour -- confirm which layout the
        # model was trained with.
        return torch.Tensor(image).transpose(2, 0).unsqueeze(0)

    def forward(self, img, labels=None):
        """Predict the three most likely species for one image.

        Args:
            img: Raw image array; see ``preprocess_image``.
            labels: Accepted for interface compatibility; not used.

        Returns:
            ``{"predictions": names}`` where ``names`` is a NumPy array of
            three entries from ``WHALE_CLASSES``, ordered from highest to
            lowest species logit.
        """
        # Keep the input on the same device as the model weights so the call
        # also works after the model has been moved off the CPU.
        tensor = self.preprocess_image(img).to(self.device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            _, head_species_logits = self.model(tensor)
        # .cpu() is required before .numpy() when the logits live on an
        # accelerator; it is a no-op on CPU.
        species_logits = head_species_logits.detach().cpu().numpy()
        # argsort is ascending: reverse it and keep the first three indices.
        top_three_idx = np.argsort(species_logits[0])[::-1][:3]
        return {"predictions": WHALE_CLASSES[top_three_idx]}