RadiXGPT_ / main.py
Singularity666's picture
Update main.py
257f974
raw
history blame
10.5 kB
from torch import nn
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import cv2
import timm
import torch
import torch.nn.functional as F
# Device used by the inference helpers in this module; pinned to the CPU
# (CFG.device separately picks CUDA when it is available).
device = torch.device(type="cpu")
class CFG:
    """Global hyper-parameters, paths and model identifiers used across the module."""

    # --- run mode and data locations ------------------------------------------
    debug = False
    image_path = "/content/content/new_images_v5"
    captions_path = "/content/content/all_data/new_caption.csv"

    # --- data loading ---------------------------------------------------------
    batch_size = 12
    num_workers = 2

    # --- optimisation / schedule ----------------------------------------------
    head_lr = 0.001
    image_encoder_lr = 0.0001
    text_encoder_lr = 0.00001
    weight_decay = 0.001
    patience = 1
    factor = 0.8
    epochs = 2

    # --- checkpoints ----------------------------------------------------------
    saved_model_clinical = "/content/content/new_weights.pt"
    trained_model = "clinical_bert_weights.pt"

    # Preferred device (CUDA when available). The inference helpers in this
    # module use the module-level `device` instead.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- encoders -------------------------------------------------------------
    model_name = "resnet50"          # timm image backbone
    image_embedding = 2048           # input size of the image projection head
    text_encoder_model = "distilbert-base-uncased"
    clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
    text_embedding = 768             # input size of the text projection head
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200                 # tokenizer truncation length
    pretrained = True                # applies to both the image and the text encoder
    trainable = True                 # applies to both the image and the text encoder
    temperature = 1.0

    # --- image preprocessing --------------------------------------------------
    size = 224                       # images are resized to size x size

    # --- projection heads (shared by the image and text branches) -------------
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
def build_loaders(dataframe, tokenizer, mode):
    """Wrap ``dataframe`` in a ``CLIPDataset`` and return a DataLoader over it.

    Reads the ``"image"`` and ``"caption"`` columns. Batches are shuffled only
    when ``mode == "train"``; batch size and worker count come from ``CFG``.
    """
    pipeline = get_transforms(mode=mode)
    is_training = mode == "train"
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_training,
    )
class AvgMeter:
    """Count-weighted running average of a scalar metric, printable as ``name: avg``."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, sum and count."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Record ``val`` as observed ``count`` times and refresh ``avg``."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)
def get_lr(optimizer):
    """Return the learning rate of ``optimizer``'s first parameter group, or None if it has none."""
    return next((group["lr"] for group in optimizer.param_groups), None)
# Custom dataset: tokenizes all captions once up front, then loads and transforms each image when an item is requested.
class CLIPDataset(torch.utils.data.Dataset):
    """Dataset of (tokenized caption, image tensor, caption) items.

    ``image_filenames`` and ``captions`` must have the same length; if an image
    has several captions, its file name must be repeated once per caption.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: image file names, resolved relative to ``CFG.image_path``.
            captions: caption strings aligned index-for-index with ``image_filenames``.
            tokenizer: called once on the full caption list with ``padding=True``,
                ``truncation=True`` and ``max_length=CFG.max_length``; its result
                must expose ``.items()`` mapping field names to per-caption values.
            transforms: called as ``transforms(image=...)`` and expected to return
                a mapping with an ``"image"`` entry.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Number of unreadable images skipped by __getitem__.
        # NOTE(review): with DataLoader worker processes this counter is likely
        # incremented on per-worker copies of the dataset - confirm before relying on it.
        self.skippedImgCount = 0
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the item for ``idx``.

        The dict holds the tokenizer's fields as tensors, ``"image"`` (float
        tensor, channels first) and ``"caption"``. If the image for ``idx``
        cannot be read, the following indices are tried in order (wrapping
        around to 0). The item for the first readable image is returned, and
        every skipped image increments ``skippedImgCount``.

        Raises:
            IndexError: if ``idx`` is out of range.
            FileNotFoundError: if none of the dataset's images can be read.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range for a dataset of {size} items")
        start = idx % size

        # Iterate (bounded by the dataset size) instead of recursing, so long runs
        # of missing images cannot exhaust the recursion limit and a dataset with
        # no readable image fails with a clear error.
        for offset in range(size):
            current = (start + offset) % size
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[current]}")
            if image is not None:
                break
            self.skippedImgCount += 1
        else:
            raise FileNotFoundError(
                f"no readable image found under {CFG.image_path} "
                f"({size} files tried, starting at index {idx})"
            )

        item = {
            key: torch.tensor(values[current])
            for key, values in self.encoded_captions.items()
        }
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)["image"]
        # Move the channel axis (last) to the front for the image encoder.
        item["image"] = torch.tensor(image).permute(2, 0, 1).float()
        item["caption"] = self.captions[current]
        return item

    def __len__(self):
        return len(self.captions)
def get_transforms(mode="train"):
    """Build the albumentations preprocessing pipeline.

    Training and validation currently share the same steps (resize to a
    ``CFG.size`` x ``CFG.size`` square, then ``A.Normalize`` with
    ``max_pixel_value=255.0``), so ``mode`` does not change the result.
    """
    steps = [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)
class ImageEncoder(nn.Module):
    """Encode images to a fixed-size vector with a timm backbone.

    The backbone is created without a classifier (``num_classes=0``) and with
    average global pooling, so ``forward`` returns pooled features.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # Freeze or unfreeze every backbone parameter in one call.
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
class TextEncoder(nn.Module):
    """Encode token sequences to one vector: the final hidden state at position 0 (CLS)."""

    def __init__(
        self,
        model_name=CFG.text_encoder_model,
        pretrained=CFG.pretrained,
        trainable=CFG.trainable,
    ):
        super().__init__()
        # NOTE(review): `model_name` is not used. The pretrained path always loads
        # CFG.clinical_encoder_model (Bio-ClinicalBERT); the non-pretrained path
        # builds a DistilBERT with its default config.
        self.model = (
            AutoModel.from_pretrained(CFG.clinical_encoder_model)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)
        # Index of the CLS token, whose hidden state serves as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
# Project image and text encodings to the same dimensionality so they can be compared.
class ProjectionHead(nn.Module):
    """Map an ``embedding_dim`` encoder output to ``projection_dim``.

    A linear projection feeds a GELU -> Linear -> Dropout branch; the branch
    output is added back to the projection (residual) and then LayerNorm'd.
    """

    def __init__(self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        residual = self.projection(x)
        branch = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(branch + residual)
class CLIPModel(nn.Module):
    """Dual image/text encoder trained with a CLIP-style contrastive objective.

    Each modality is encoded (``ImageEncoder`` / ``TextEncoder``) and projected
    by its own ``ProjectionHead``; ``forward`` returns the scalar loss for a batch.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        # Attribute names must match the keys stored in weights.pt, which
        # load_state_dict reads elsewhere in this module.
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the mean symmetric contrastive loss for ``batch``.

        ``batch`` must provide ``"image"``, ``"input_ids"`` and ``"attention_mask"``.
        """
        # Encode both modalities, then project both (this order keeps any
        # dropout randomness consumed in the same sequence).
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        temp = self.temperature
        # Rows index texts, columns index images.
        logits = (txt_emb @ img_emb.T) / temp
        # Soft targets from the average of the two intra-modal similarity matrices.
        avg_similarity = (img_emb @ img_emb.T + txt_emb @ txt_emb.T) / 2
        targets = F.softmax(avg_similarity * temp, dim=-1)

        text_side = cross_entropy(logits, targets, reduction='none')
        image_side = cross_entropy(logits.T, targets.T, reduction='none')
        # Per-sample loss averaged over both directions, then over the batch.
        return ((image_side + text_side) / 2.0).mean()
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft target distributions ``targets``.

    Each row's loss is ``-sum(targets * log_softmax(preds))`` over the last dimension.

    Args:
        preds: unnormalised logits.
        targets: target probabilities, same shape as ``preds``.
        reduction: ``"none"`` returns the per-row losses, ``"mean"`` their mean,
            ``"sum"`` their sum.

    Raises:
        ValueError: for any other ``reduction`` value, instead of silently
            returning ``None``.
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
    )
# INFERENCE CODE
def get_image_embeddings(image, model=None):
    """Embed a single image with the CLIP image tower.

    Args:
        image: image array as returned by ``cv2.imread`` (converted from BGR to
            RGB here), or ``None``.
        model: optional already-loaded ``CLIPModel``. When omitted, a new model
            is built and its weights are loaded from ``weights.pt`` on every
            call; pass a model to reuse it across calls. The model is switched
            to eval mode.

    Returns:
        A ``(1, projection_dim)`` tensor of L2-normalised embeddings, or ``None``
        (after printing a message) when ``image`` is ``None``.
    """
    if image is None:
        print("Image not found!")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms("valid")(image=image)["image"]
    # The transforms return a height x width x channel array. Move the channel
    # axis to the front exactly as CLIPDataset does for training
    # (permute(2, 0, 1)); reshaping to (3, H, W) would scramble pixels across
    # channels instead of transposing them.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float()

    if model is None:
        model = CLIPModel().to(device)
        model.load_state_dict(torch.load("weights.pt", map_location=device))
    model.eval()
    # Feed the input to whichever device the model's parameters live on.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        image_features = model.image_encoder(image_tensor.unsqueeze(0).to(model_device))
        image_embeddings = model.image_projection(image_features)
    return F.normalize(image_embeddings, p=2, dim=-1)
def predict_caption(image, model, text_embeddings, captions, n=2):
    """Return the ``n`` captions whose text embeddings best match ``image``.

    Args:
        image: image array accepted by ``get_image_embeddings``; ``None`` yields ``None``.
        model: NOTE(review): currently unused. ``get_image_embeddings(image)``
            loads its own model from ``weights.pt`` on every call.
        text_embeddings: one embedding row per entry of ``captions``
            (presumably from ``get_text_embeddings``; confirm the alignment).
        captions: indexable sequence of caption strings.
        n: number of matches wanted; clamped to the number of available rows
            (``torch.topk`` raises when ``k`` exceeds the dimension size).

    Returns:
        Up to ``n`` captions ordered from most to least similar, or ``None`` if
        the image could not be embedded.
    """
    image_embeddings = get_image_embeddings(image)
    if image_embeddings is None:
        return None

    # Normalise both sides so the dot product is cosine similarity.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = image_embeddings_n @ text_embeddings_n.T

    # Drop the single-image batch dimension, leaving one score per text row.
    scores = dot_similarity.squeeze(0)
    k = min(n, scores.shape[-1])
    _, indices = torch.topk(scores, k)
    return [captions[idx] for idx in indices.tolist()]
def get_text_embeddings(valid_df):
    """Embed every caption in ``valid_df`` with the trained CLIP text tower.

    Loads ``CLIPModel`` weights from ``weights.pt`` and encodes
    ``valid_df["caption"]`` in batches of ``CFG.batch_size``.

    Captions are tokenized directly rather than through ``build_loaders``.
    That loader also reads each row's image, and ``CLIPDataset`` substitutes the
    next row when an image is unreadable. That shifted embeddings out of
    alignment with ``valid_df`` rows and required every image to be present.

    Returns:
        ``(model, embeddings)``, where row ``i`` of ``embeddings`` corresponds to
        row ``i`` of ``valid_df`` (an empty ``(0, CFG.projection_dim)`` tensor
        when ``valid_df`` has no rows).
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.clinical_encoder_model)
    captions = list(valid_df["caption"].values)

    model = CLIPModel().to(device)
    model.load_state_dict(torch.load("weights.pt", map_location=device))
    model.eval()

    if not captions:
        return model, torch.empty((0, CFG.projection_dim), device=device)

    # Same tokenizer settings CLIPDataset uses: pad to the longest caption,
    # truncate at CFG.max_length.
    encoded = tokenizer(
        captions,
        padding=True,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    batch_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(captions), CFG.batch_size)):
            stop = start + CFG.batch_size
            text_features = model.text_encoder(
                input_ids=input_ids[start:stop].to(device),
                attention_mask=attention_mask[start:stop].to(device),
            )
            batch_embeddings.append(model.text_projection(text_features))
    return model, torch.cat(batch_embeddings)