import torch
import numpy.typing as npt
import torch.nn.functional as F
from app.configs import DEVICE
from app.mobile_sam import SamPredictor
from .model import point_selection, MaskWeights
from .loss import calculate_dice_loss, calculate_sigmoid_focal_loss


def _combined_mask_weights(mask_weights: MaskWeights) -> torch.Tensor:
    """Return the full weight vector used to fuse the decoder's multimask logits.

    The first entry is ``1 - sum(learned weights)`` and the remaining entries
    are the learnable ``mask_weights.weights``, so the vector sums to 1 along
    dim 0.
    """
    learned = mask_weights.weights
    return torch.cat((1 - learned.sum(0).unsqueeze(0), learned), dim=0)


def _encode_reference(
    predictor: SamPredictor,
    ref_image: npt.NDArray,
    ref_mask: npt.NDArray,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode one reference image/mask pair and derive its training inputs.

    This loads ``ref_image`` into ``predictor`` and, while it is still the
    loaded image, computes the similarity map and runs the decoder, since both
    read the predictor's current image state.

    Returns:
        ``(gt_mask, target_feat, logits)``:
        ``gt_mask`` is the binarized reference mask (pixels ``> 0``), flattened
        to shape ``(1, -1)`` and moved to ``DEVICE``.
        ``target_feat`` is the L2-normalized ``(1, C)`` target embedding.
        ``logits`` are the decoder's multimask logits, flattened with
        ``flatten(1)`` and detached from any autograd graph.

    Raises:
        ValueError: if the mask has no foreground at feature resolution.
    """
    gt_mask = (torch.from_numpy(ref_mask) > 0).float().unsqueeze(0).flatten(1)
    gt_mask = gt_mask.to(DEVICE)

    # Image features encoding
    predictor.set_image(ref_image)
    feat_mask = predictor.get_mask(ref_mask[:, :, None])
    feats = predictor.features.squeeze().permute(1, 2, 0)
    h, w, C = feats.shape

    # Resize the reference mask onto the (h, w) feature grid.
    feat_mask = F.interpolate(feat_mask, size=(h, w), mode="bilinear").squeeze()

    # Target embedding: average of the max and the mean foreground features.
    fg_feats = feats[feat_mask > 0]
    if fg_feats.shape[0] == 0:
        # torch.max over an empty dim would fail with an opaque error.
        raise ValueError(
            "Reference mask has no foreground pixels at feature resolution"
        )
    feat_max = torch.max(fg_feats, dim=0)[0]
    feat_mean = fg_feats.mean(0)
    target_feat = (feat_max / 2 + feat_mean / 2).unsqueeze(0)
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)

    # Cosine similarity between the target embedding and every feature cell.
    feats = feats / feats.norm(dim=-1, keepdim=True)
    sim = target_feat @ feats.permute(2, 0, 1).reshape(C, h * w)
    sim = F.interpolate(sim.reshape(1, 1, h, w), scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    # Positive location prior (point_selection with topk=1).
    topk_xy, topk_label = point_selection(sim, topk=1)

    # Run the decoder now, while this image is still loaded in the predictor.
    # Previously this ran later, after the predictor had moved on to the last
    # reference image. Its output does not depend on the learnable mask
    # weights, so it is computed once and reused every epoch. detach() keeps
    # the repeated backward passes from touching any graph the predictor
    # may have built.
    logits, _, _ = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        multimask_output=True,
        return_logits=True,
        return_numpy=False,
    )
    return gt_mask, target_feat, logits.flatten(1).detach()


def train(
    predictor: SamPredictor,
    ref_images: list[npt.NDArray],
    ref_masks: list[npt.NDArray],
    lr: float = 1e-3,
    epochs: int = 200,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Learn multimask fusion weights and a target embedding from references.

    Each reference image is encoded, and its mask is used to build a target
    embedding and a point prompt. The decoder's multimask logits for that
    prompt are fused with learnable weights. Those weights are trained
    against the reference mask with dice + sigmoid focal loss, using AdamW
    and a cosine-annealed learning rate.

    Args:
        predictor: Predictor used to encode images and run the mask decoder.
            On return, its loaded image is the last reference image.
        ref_images: Reference images, paired by position with ``ref_masks``.
        ref_masks: Reference masks; pixels ``> 0`` are foreground.
        lr: AdamW learning rate.
        epochs: Number of passes over all references (also the cosine T_max).

    Returns:
        ``(weights, target_feat)``:
        ``weights`` is the fused weight vector, which sums to 1 along dim 0.
        ``target_feat`` is the mean of the per-reference L2-normalized target
        embeddings; the mean itself is not re-normalized.

    Raises:
        ValueError: if no references are given, if ``ref_images`` and
            ``ref_masks`` differ in length, or if a mask has no foreground at
            feature resolution.
    """
    if len(ref_images) != len(ref_masks):
        raise ValueError(
            f"Got {len(ref_images)} reference images but {len(ref_masks)} masks"
        )
    if len(ref_images) == 0:
        raise ValueError("At least one reference image/mask pair is required")

    gt_masks: list[torch.Tensor] = []
    target_feats: list[torch.Tensor] = []
    decoder_logits: list[torch.Tensor] = []
    for ref_image, ref_mask in zip(ref_images, ref_masks):
        gt_mask, ref_target, logits = _encode_reference(predictor, ref_image, ref_mask)
        gt_masks.append(gt_mask)
        target_feats.append(ref_target)
        decoder_logits.append(logits)

    target_feat = torch.cat(target_feats, dim=0).mean(dim=0)

    # Learnable mask weights
    mask_weights = MaskWeights().to(DEVICE)
    mask_weights.train()
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=lr, eps=1e-4)
    # T_max == epochs, so the scheduler is stepped once per epoch, not per image.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for _ in range(epochs):
        for gt_mask, logits in zip(gt_masks, decoder_logits):
            # Weighted sum of the multimask logits into a single mask.
            weights = _combined_mask_weights(mask_weights)
            fused = (logits * weights).sum(0).unsqueeze(0)

            dice_loss = calculate_dice_loss(fused, gt_mask)
            focal_loss = calculate_sigmoid_focal_loss(fused, gt_mask, alpha=1.0)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

    mask_weights.eval()
    return _combined_mask_weights(mask_weights), target_feat