Spaces:
Build error
Build error
import torch | |
import numpy.typing as npt | |
import torch.nn.functional as F | |
from app.configs import DEVICE | |
from app.mobile_sam import SamPredictor | |
from .model import point_selection, MaskWeights | |
from .loss import calculate_dice_loss, calculate_sigmoid_focal_loss | |
def _mix_mask_weights(mask_weights: MaskWeights) -> torch.Tensor:
    """Return the per-scale mixing weights, prepending ``1 - sum(learned)``.

    By construction the returned entries sum to 1 along dim 0.
    """
    learned = mask_weights.weights
    return torch.cat((1 - learned.sum(0).unsqueeze(0), learned), dim=0)


def _encode_reference(
    predictor: SamPredictor,
    ref_image: npt.NDArray,
    ref_mask: npt.NDArray,
) -> tuple[torch.Tensor, tuple, torch.Tensor, tuple]:
    """Encode one reference image/mask pair and derive its training inputs.

    Calls ``predictor.set_image(ref_image)``, so the predictor is left holding
    this image's embedding.

    Returns:
        ``(gt_mask, (topk_xy, topk_label), target_feat, image_state)``:

        - ``gt_mask``: ``ref_mask > 0`` as float, flattened to ``(1, H * W)``,
          on ``DEVICE``.
        - ``(topk_xy, topk_label)``: positive point prompt chosen by
          ``point_selection`` on the cosine-similarity map.
        - ``target_feat``: L2-normalized target embedding of shape ``(1, C)``.
        - ``image_state``: ``(features, input_size, original_size)`` read from
          the predictor, so this image can be re-selected later without
          re-running the image encoder.

    Raises:
        ValueError: if the mask covers no feature-map location after resizing
            (e.g. a very small mask).
    """
    gt_mask = (torch.from_numpy(ref_mask) > 0).float().unsqueeze(0).flatten(1).to(DEVICE)

    # Image features encoding.
    predictor.set_image(ref_image)
    feat_mask = predictor.get_mask(ref_mask[:, :, None])
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)  # (h, w, C)
    feat_mask = F.interpolate(feat_mask, size=ref_feat.shape[0:2], mode="bilinear")
    # assumes get_mask yields a single-channel mask, so this squeezes to (h, w)
    # -- TODO confirm
    feat_mask = feat_mask.squeeze()

    # Target feature extraction: equal blend of max- and mean-pooled features
    # inside the mask.
    masked_feats = ref_feat[feat_mask > 0]  # (N, C)
    if masked_feats.shape[0] == 0:
        raise ValueError(
            "reference mask has no positive pixels at feature-map resolution"
        )
    feat_max = masked_feats.max(dim=0).values
    feat_mean = masked_feats.mean(dim=0)
    target_feat = (feat_max / 2 + feat_mean / 2).unsqueeze(0)
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)

    # Cosine similarity between the target embedding and every location.
    h, w, C = ref_feat.shape
    ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
    sim = (target_feat @ ref_feat).reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    # Positive location prior.
    topk_xy, topk_label = point_selection(sim, topk=1)

    image_state = (predictor.features, predictor.input_size, predictor.original_size)
    return gt_mask, (topk_xy, topk_label), target_feat, image_state


def train(
    predictor: SamPredictor,
    ref_images: list[npt.NDArray],
    ref_masks: list[npt.NDArray],
    lr: float = 1e-3,
    epochs: int = 200,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Learn mask-scale mixing weights from reference image/mask pairs.

    For each reference, a target embedding and a single positive point prompt
    are derived (see ``_encode_reference``). A ``MaskWeights`` module is then
    optimized with AdamW (``eps=1e-4``) and cosine annealing over ``epochs``.
    The optimized quantity is the weighted sum of the multimask decoder logits,
    trained against each reference mask with dice + sigmoid focal loss
    (``alpha=1.0``). Only ``mask_weights`` parameters are optimized.

    Args:
        predictor: SAM predictor. It is re-pointed at each reference image's
            embedding during training and is left on the last one.
        ref_images: Reference images, one per mask.
        ref_masks: Reference masks aligned with ``ref_images``; pixels ``> 0``
            are foreground.
        lr: AdamW learning rate.
        epochs: Number of passes over all references; also the
            cosine-annealing period.

    Returns:
        ``(weights, target_feat)``:

        - ``weights``: detached mixing weights whose entries sum to 1 along
          dim 0.
        - ``target_feat``: mean of the per-reference normalized target
          embeddings, shape ``(C,)``.

    Raises:
        ValueError: if no references are given, if ``ref_images`` and
            ``ref_masks`` differ in length, or if a mask is empty at
            feature-map resolution.
    """
    if len(ref_images) != len(ref_masks):
        raise ValueError(
            f"got {len(ref_images)} reference images but {len(ref_masks)} masks"
        )
    if not ref_images:
        raise ValueError("at least one reference image/mask pair is required")

    gt_masks = []
    points = []
    target_feats = []
    image_states = []
    for ref_image, ref_mask in zip(ref_images, ref_masks):
        gt_mask, point, target_feat, image_state = _encode_reference(
            predictor, ref_image, ref_mask
        )
        gt_masks.append(gt_mask)
        points.append(point)
        target_feats.append(target_feat)
        image_states.append(image_state)

    target_feat = torch.cat(target_feats, dim=0).mean(dim=0)

    # Learnable mask weights.
    mask_weights = MaskWeights().to(DEVICE)
    mask_weights.train()
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for _ in range(epochs):
        for gt_mask, (topk_xy, topk_label), image_state in zip(
            gt_masks, points, image_states
        ):
            # set_image only keeps the most recent embedding. Point the
            # predictor back at this reference so the prompt, the decoded
            # logits and gt_mask all refer to the same image and size.
            predictor.features, predictor.input_size, predictor.original_size = (
                image_state
            )

            # Run the decoder.
            logits_high, _, _ = predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                multimask_output=True,
                return_logits=True,
                return_numpy=False,
            )
            logits_high = logits_high.flatten(1)

            # Weighted sum of the multi-scale masks.
            weights = _mix_mask_weights(mask_weights)
            logits_high = (logits_high * weights).sum(0).unsqueeze(0)

            dice_loss = calculate_dice_loss(logits_high, gt_mask)
            focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask, alpha=1.0)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

    mask_weights.eval()
    # Inference-only result: don't hand callers a tensor that keeps the graph.
    with torch.no_grad():
        weights = _mix_mask_weights(mask_weights)
    return weights, target_feat