import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig


#%% set up model
class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 clip_ckpt,
                 roi_size,
                 patch_size,
                 test_mode=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder(clip_ckpt)
        # spatial shape of the encoder feature volume: ROI size divided by patch size per axis
        self.feat_shape = np.array(roi_size) / np.array(patch_size)
        self.test_mode = test_mode

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        image_embedding, _ = self.image_encoder(image)
        # reshape the (B, N, C) token sequence into a (B, C, D, H, W) feature volume
        image_embedding = image_embedding.transpose(1, 2).view(
            bs, -1,
            int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]),
        )
        # test mode
        if self.test_mode:
            return self.forward_decoder(image_embedding, img_shape, text, boxes, points)

        # train mode
        # future release

    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        with torch.no_grad():
            if boxes is not None and len(boxes.shape) == 2:
                boxes = boxes[:, None, :]  # (B, 1, 6)
            # encode the text prompt with the frozen text encoder
            text_embedding = self.text_encoder(text) if text is not None else None  # (B, 768)
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )

        dense_pe = self.prompt_encoder.get_dense_pe()
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding=text_embedding,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # upsample the low-resolution mask logits back to the input volume size
        logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
        return logits


class TextEncoder(nn.Module):
    def __init__(self, clip_ckpt):
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
        # project the 512-dim CLIP text feature to the 768-dim prompt space
        self.dim_align = nn.Linear(512, 768)
        # freeze text encoder
        for param in self.clip_text_model.parameters():
            param.requires_grad = False

    def organ2tokens(self, organ_names):
        # wrap each organ name in a CT-specific prompt template before tokenizing
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        return tokens

    def forward(self, text):
        if text is None:
            return None
        if isinstance(text, str):
            text = [text]
        tokens = self.organ2tokens(text)
        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = clip_outputs.pooler_output
        text_embedding = self.dim_align(text_embedding)
        return text_embedding
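

#%% usage sketch (illustrative only)
# A minimal smoke test of TextEncoder, the one component fully defined in this
# file. The image encoder, mask decoder, and prompt encoder required by SegVol
# come from a SAM-style model and are not defined here, so they are not mocked.
# "openai/clip-vit-base-patch32" is an assumed CLIP checkpoint whose tokenizer
# matches the default 512-dim CLIPTextConfig used above; substitute your own
# clip_ckpt path as needed.
if __name__ == '__main__':
    text_encoder = TextEncoder(clip_ckpt='openai/clip-vit-base-patch32')
    emb = text_encoder(['liver', 'kidney'])
    print(emb.shape)  # torch.Size([2, 768]) after dim_align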