import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
#%% set up model
class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 clip_ckpt,
                 roi_size,
                 patch_size,
                 test_mode=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder(clip_ckpt)
        # spatial shape of the patch-token feature map, e.g. roi 32^3 / patch 4^3 -> 8^3
        self.feat_shape = np.array(roi_size) / np.array(patch_size)
        self.test_mode = test_mode
    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        # the ViT-style encoder returns (B, N, C) patch tokens; reshape them
        # into a (B, C, D, H, W) volumetric feature map for the decoder
        image_embedding, _ = self.image_encoder(image)
        image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
            int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
        # test mode
        if self.test_mode:
            return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
        # train mode
        # future release
    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        with torch.no_grad():
            if boxes is not None:
                if len(boxes.shape) == 2:
                    boxes = boxes[:, None, :]  # (B, 1, 6)
            if text is not None:
                text_embedding = self.text_encoder(text)  # (B, 768)
            else:
                text_embedding = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )
        dense_pe = self.prompt_encoder.get_dense_pe()
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding=text_embedding,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # upsample the low-resolution masks back to the input volume size;
        # the return value is raw (pre-sigmoid) logits
        logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
        return logits
class TextEncoder(nn.Module):
    def __init__(self, clip_ckpt):
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
        # project CLIP's 512-dim pooled text feature into the 768-dim prompt space
        self.dim_align = nn.Linear(512, 768)
        # freeze text encoder
        for param in self.clip_text_model.parameters():
            param.requires_grad = False
    def organ2tokens(self, organ_names):
        # wrap each organ name in a fixed CT prompt template before tokenizing
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        return tokens
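    # Example (illustrative, assuming the standard CLIP tokenizer):
    #   organ2tokens(['liver']) tokenizes 'A computerized tomography of a
    #   liver.' and returns a BatchEncoding with input_ids / attention_mask
    #   tensors of shape (1, seq_len).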
    def forward(self, text):
        if text is None:
            return None
        if isinstance(text, str):
            text = [text]
        tokens = self.organ2tokens(text)
        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = clip_outputs.pooler_output
        text_embedding = self.dim_align(text_embedding)
        return text_embedding
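
#%% usage sketch (not part of the original file)
# A minimal, illustrative wiring of SegVol in test mode. The dummy modules
# below only mimic the call signatures SegVol expects from its SAM-style 3D
# image encoder, prompt encoder, and mask decoder; the real components come
# from the SegVol codebase. 'openai/clip-vit-base-patch32' is one tokenizer
# checkpoint whose vocabulary matches the default CLIPTextConfig (512-dim
# pooled output) assumed by dim_align; the shapes below are assumptions for
# the sketch, not values fixed by this file.
if __name__ == '__main__':
    roi_size, patch_size = (32, 32, 32), (4, 4, 4)
    d, h, w = 8, 8, 8      # feature-map resolution = roi_size / patch_size
    embed_dim = 768

    class DummyImageEncoder(nn.Module):
        # returns (patch_tokens, hidden_states) like a ViT; tokens are (B, N, C)
        def forward(self, x):
            return torch.randn(x.shape[0], d * h * w, embed_dim), None

    class DummyPromptEncoder(nn.Module):
        # mirrors SAM's prompt-encoder interface, extended with text_embedding
        def forward(self, points=None, boxes=None, masks=None, text_embedding=None):
            sparse = torch.zeros(1, 0, embed_dim)
            dense = torch.zeros(1, embed_dim, d, h, w)
            return sparse, dense

        def get_dense_pe(self):
            return torch.zeros(1, embed_dim, d, h, w)

    class DummyMaskDecoder(nn.Module):
        # returns (low_res_masks, iou_predictions) like SAM's mask decoder
        def forward(self, image_embeddings, text_embedding, image_pe,
                    sparse_prompt_embeddings, dense_prompt_embeddings,
                    multimask_output):
            return torch.randn(image_embeddings.shape[0], 1, d, h, w), None

    model = SegVol(
        image_encoder=DummyImageEncoder(),
        mask_decoder=DummyMaskDecoder(),
        prompt_encoder=DummyPromptEncoder(),
        clip_ckpt='openai/clip-vit-base-patch32',  # downloads the tokenizer
        roi_size=roi_size,
        patch_size=patch_size,
        test_mode=True,
    )
    ct = torch.randn(1, 1, 32, 32, 32)   # (B, C, D, H, W) CT volume
    logits = model(ct, text=['liver'])   # text-prompted segmentation
    print(logits.shape)                  # torch.Size([1, 1, 32, 32, 32])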