|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from utils import CONFIG |
|
from networks import m2ms, ops |
|
import sys |
|
sys.path.insert(0, './segment-anything') |
|
from segment_anything import sam_model_registry |
|
|
|
class sam_m2m(nn.Module):
    """Pair a segmentation model from ``sam_model_registry`` with an M2M head.

    The segmentation model is switched to eval mode and only ever called
    under ``torch.no_grad()``; the M2M head is called outside that context.
    """

    def __init__(self, m2m):
        """Build the M2M head named ``m2m`` and the ``'vit_b'`` segmentation model.

        Args:
            m2m: Name of a constructor exported through ``m2ms.__all__``.

        Raises:
            NotImplementedError: If ``m2m`` is not listed in ``m2ms.__all__``.
        """
        super().__init__()

        # Validate the name before any model is constructed.
        if m2m not in m2ms.__all__:
            raise NotImplementedError("Unknown M2M {}".format(m2m))

        m2m_factory = vars(m2ms)[m2m]
        self.m2m = m2m_factory(nc=256)

        # The registry entry is invoked with checkpoint=None.
        build_seg_model = sam_model_registry['vit_b']
        self.seg_model = build_seg_model(checkpoint=None)
        self.seg_model.eval()

    def forward(self, image, guidance):
        """Run the segmentation model without gradients, then the M2M head.

        Args:
            image: Passed to ``seg_model.forward_m2m`` and to the M2M head.
            guidance: Passed to ``seg_model.forward_m2m`` alongside ``image``.

        Returns:
            Whatever the M2M head returns for ``(features, image, masks)``.
        """
        segmenter = self.seg_model
        # Re-assert eval mode on every call, since a parent ``.train()`` call
        # would otherwise flip this submodule back to training mode.
        segmenter.eval()

        with torch.no_grad():
            features, coarse_masks = segmenter.forward_m2m(
                image, guidance, multimask_output=True
            )

        # The M2M head runs outside the no-grad context.
        return self.m2m(features, image, coarse_masks)

    def forward_inference(self, image_dict):
        """Inference variant that also returns features and post-processed masks.

        Args:
            image_dict: Mapping passed to ``seg_model.forward_m2m_inference``;
                its ``"image"`` entry is also fed to the M2M head.

        Returns:
            Tuple ``(feas, pred, post_masks)``: the features and third output
            of ``forward_m2m_inference`` around the M2M head's prediction.
        """
        segmenter = self.seg_model
        segmenter.eval()

        with torch.no_grad():
            seg_outputs = segmenter.forward_m2m_inference(
                image_dict, multimask_output=True
            )
        features, coarse_masks, post_masks = seg_outputs

        prediction = self.m2m(features, image_dict["image"], coarse_masks)
        return features, prediction, post_masks
|
|
|
def get_generator_m2m(seg, m2m):
    """Build the generator for the given segmentation backbone and M2M head.

    Args:
        seg: Name of the segmentation backbone. Only ``'sam'`` is handled.
        m2m: Name of the M2M head, forwarded to ``sam_m2m``.

    Returns:
        A ``sam_m2m`` instance when ``seg == 'sam'``.

    Raises:
        NotImplementedError: If ``seg`` is not ``'sam'``. ``sam_m2m`` raises
            the same exception type for an unknown ``m2m`` name.
    """
    if seg == 'sam':
        return sam_m2m(m2m=m2m)

    # Any other value used to fall through to `return generator` with the
    # name unbound, raising an opaque UnboundLocalError. Fail explicitly.
    raise NotImplementedError("Unknown segmentation model {}".format(seg))