dillonlaird's picture
initial commit
6723494
raw
history blame
No virus
3.78 kB
import torch
import numpy.typing as npt
import torch.nn.functional as F
from app.configs import DEVICE
from app.mobile_sam import SamPredictor
from .model import point_selection, MaskWeights
from .loss import calculate_dice_loss, calculate_sigmoid_focal_loss
# Attributes that `set_image` writes on the upstream segment-anything SamPredictor.
# NOTE(review): assumes app.mobile_sam.SamPredictor keeps its per-image state in
# these same attributes (missing ones are skipped) -- confirm if the predictor changes.
_IMAGE_STATE_ATTRS = (
    "features",
    "original_size",
    "input_size",
    "is_image_set",
    "orig_h",
    "orig_w",
    "input_h",
    "input_w",
)


def _snapshot_image_state(predictor: SamPredictor) -> dict:
    """Capture the per-image state left on ``predictor`` by its last ``set_image``."""
    return {
        name: getattr(predictor, name)
        for name in _IMAGE_STATE_ATTRS
        if hasattr(predictor, name)
    }


def _restore_image_state(predictor: SamPredictor, state: dict) -> None:
    """Re-install a snapshot from ``_snapshot_image_state`` without re-encoding."""
    for name, value in state.items():
        setattr(predictor, name, value)


def _fused_mask_weights(mask_weights: MaskWeights) -> torch.Tensor:
    """Stack ``1 - sum(learned)`` on top of the learned weights so they sum to 1 along dim 0."""
    learned = mask_weights.weights
    return torch.cat((1 - learned.sum(0).unsqueeze(0), learned), dim=0)


def _encode_reference(
    predictor: SamPredictor,
    ref_image: npt.NDArray,
    ref_mask: npt.NDArray,
) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
    """Encode one reference image/mask pair.

    Returns:
        ``(gt_mask, target_feat, point, image_state)``:

        - ``gt_mask``: the binarized mask, flattened to shape ``(1, H*W)``.
        - ``target_feat``: the L2-normalized target feature, shape ``(1, C)``.
        - ``point``: the ``(topk_xy, topk_label)`` prompt from ``point_selection``.
        - ``image_state``: the predictor state for this image.

    Raises:
        ValueError: if no embedding location falls inside the mask.
    """
    gt_mask = torch.from_numpy(ref_mask) > 0
    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(DEVICE)

    # Image features encoding; snapshot immediately so training can reuse it.
    predictor.set_image(ref_image)
    image_state = _snapshot_image_state(predictor)
    low_res_mask = predictor.get_mask(ref_mask[:, :, None])
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)  # (h, w, C)
    low_res_mask = F.interpolate(
        low_res_mask, size=ref_feat.shape[0:2], mode="bilinear"
    ).squeeze()

    # Target feature: average of the max- and mean-pooled in-mask features.
    masked_feat = ref_feat[low_res_mask > 0]
    if masked_feat.shape[0] == 0:
        raise ValueError(
            "reference mask has no foreground at the image-embedding resolution"
        )
    target_feat = (
        masked_feat.max(dim=0).values / 2 + masked_feat.mean(0) / 2
    ).unsqueeze(0)
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)

    # Cosine similarity between the target feature and every embedding location.
    h, w, C = ref_feat.shape
    ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
    sim = (target_feat @ ref_feat).reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim, input_size=predictor.input_size, original_size=predictor.original_size
    ).squeeze()

    # Positive location prior: the single best-matching location.
    point = point_selection(sim, topk=1)
    return gt_mask, target_feat, point, image_state


def train(
    predictor: SamPredictor,
    ref_images: list[npt.NDArray],
    ref_masks: list[npt.NDArray],
    lr: float = 1e-3,
    epochs: int = 200,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Learn fusion weights for the decoder's multimask logits from reference pairs.

    For each reference, the image is encoded with ``predictor``. A normalized
    target feature is pooled from the embedding locations inside the mask, and
    the best-matching location is chosen as a positive point prompt. Then
    ``MaskWeights`` is optimized with AdamW and cosine annealing so that the
    weighted sum of the decoder's multimask logits matches each ground-truth
    mask under dice + sigmoid focal loss.

    Args:
        predictor: Used to encode images and decode masks. Its image state is
            overwritten; on return it holds the last reference image's state.
        ref_images: Reference images, passed to ``predictor.set_image``.
        ref_masks: Masks aligned with ``ref_images``; values > 0 are
            foreground. Presumably 2-D ``(H, W)``, since they are indexed
            as ``mask[:, :, None]``.
        lr: AdamW learning rate.
        epochs: Number of passes over all references; also the cosine
            schedule length (T_max).

    Returns:
        ``(weights, target_feat)``:

        - ``weights``: the fused mask weights, detached, summing to 1 along
          dim 0.
        - ``target_feat``: the mean of the per-reference normalized target
          features.

    Raises:
        ValueError: if no references are given, the two lists differ in
            length, or a mask has no foreground at embedding resolution.
    """
    if len(ref_images) != len(ref_masks):
        raise ValueError(
            f"got {len(ref_images)} reference images but {len(ref_masks)} masks"
        )
    if not ref_images:
        raise ValueError("at least one reference image/mask pair is required")

    references = [
        _encode_reference(predictor, ref_image, ref_mask)
        for ref_image, ref_mask in zip(ref_images, ref_masks)
    ]
    target_feat = torch.cat([feat for _, feat, _, _ in references], dim=0).mean(dim=0)

    # Learnable mask weights
    mask_weights = MaskWeights().to(DEVICE)
    mask_weights.train()
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for _epoch in range(epochs):
        for gt_mask, _feat, (topk_xy, topk_label), image_state in references:
            # Decode against THIS reference's embedding, not the last one encoded.
            _restore_image_state(predictor, image_state)
            logits_high, _, _ = predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                multimask_output=True,
                return_logits=True,
                return_numpy=False,
            )
            logits_high = logits_high.flatten(1)

            # Weighted sum of the multi-scale mask logits.
            weights = _fused_mask_weights(mask_weights)
            fused_logits = (logits_high * weights).sum(0).unsqueeze(0)

            dice_loss = calculate_dice_loss(fused_logits, gt_mask)
            focal_loss = calculate_sigmoid_focal_loss(fused_logits, gt_mask, alpha=1.0)
            loss = dice_loss + focal_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # One scheduler step per epoch so the cosine schedule spans T_max=epochs.
        scheduler.step()

    mask_weights.eval()
    # Detach: callers use the weights for inference, not further backprop.
    weights = _fused_mask_weights(mask_weights).detach()
    return weights, target_feat