# RadiXGPT_ / main.py
# Singularity666's picture
# Update main.py
# 61e3710
from torch import nn
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModel
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
import albumentations as A
import cv2
import timm
import torch
import torch.nn.functional as F
# Module-wide compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class CFG:
    """Static configuration namespace read by the training and inference code."""

    debug = False

    # ---- file locations ----
    image_path = "/content/content/new_images_v5"
    captions_path = "/content/content/all_data/new_caption.csv"
    saved_model_clinical = "/content/content/new_weights.pt"
    trained_model = "clinical_bert_weights.pt"

    # ---- data loading ----
    batch_size = 12
    num_workers = 2

    # ---- optimisation hyperparameters ----
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 2

    # ---- runtime ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- image encoder ----
    model_name = "resnet50"
    image_embedding = 2048
    size = 224  # images are resized to size x size

    # ---- text encoder ----
    text_encoder_model = "distilbert-base-uncased"
    clinical_encoder_model = "emilyalsentzer/Bio_ClinicalBERT"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    # ---- flags applied to both the image encoder and the text encoder ----
    pretrained = True
    trainable = True

    temperature = 1.0

    # ---- projection head (used for both image and text encoders) ----
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1
def build_loaders(dataframe, tokenizer, mode):
    """Create a DataLoader over the image/caption pairs in ``dataframe``.

    Args:
        dataframe: Table exposing ``"image"`` and ``"caption"`` columns whose
            ``.values`` are handed to ``CLIPDataset``.
        tokenizer: Callable passed through to ``CLIPDataset`` to encode captions.
        mode: Forwarded to ``get_transforms``; batches are shuffled only when
            it equals ``"train"``.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``CFG.batch_size`` and
        ``CFG.num_workers``.
    """
    pipeline = get_transforms(mode=mode)
    pairs = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=pipeline,
    )
    return torch.utils.data.DataLoader(
        pairs,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )
class AvgMeter:
    """Running weighted average of a scalar metric.

    Attributes are ``avg`` (current mean), ``sum`` (weighted total) and
    ``count`` (total weight seen so far).
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in ``val`` observed ``count`` times and refresh the mean."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)
# Custom dataset object. Will tokenize text and apply transforms to images before yielding them.
class CLIPDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding tokenized captions paired with transformed images."""

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names

        Args:
            image_filenames: File names resolved relative to ``CFG.image_path``.
            captions: One caption per entry of ``image_filenames``.
            tokenizer: Callable invoked once on the full caption list with
                ``padding=True, truncation=True, max_length=CFG.max_length``.
            transforms: Callable invoked as ``transforms(image=...)`` whose
                result exposes the processed image under ``'image'``.
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        # Number of unreadable images encountered so far by __getitem__.
        self.skippedImgCount = 0
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the next readable one after it.

        When the image for ``idx`` cannot be read, following indices are tried
        (wrapping around) and the first sample with a readable image is
        returned in its place; ``skippedImgCount`` grows by one per unreadable
        file.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no image in the whole dataset could be read.
        """
        total = len(self)
        # Keep sequence semantics: out-of-range indices must not wrap around.
        if not -total <= idx < total:
            raise IndexError(f"index {idx} is out of range for dataset of size {total}")

        # Bounded scan instead of unbounded recursion: a long run of missing
        # files no longer exhausts the interpreter's recursion limit.
        current = idx
        image = None
        for _ in range(total):
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[current]}")
            if image is not None:
                break
            self.skippedImgCount += 1
            current = (current + 1) % total
        if image is None:
            raise RuntimeError(
                f"None of the {total} images under {CFG.image_path} could be read"
            )

        # Tokens are gathered only for the sample that is actually returned.
        item = {
            key: torch.tensor(values[current])
            for key, values in self.encoded_captions.items()
        }
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Move the channel axis first (HWC -> CHW) before handing to the model.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[current]
        return item

    def __len__(self):
        return len(self.captions)
def get_transforms(mode="train"):
    """Return the albumentations pipeline applied to every image.

    The pipeline resizes to ``CFG.size`` x ``CFG.size`` and then normalizes
    with ``max_pixel_value=255.0``.

    Args:
        mode: Accepted for interface compatibility; ``"train"`` and every
            other value currently produce the same pipeline.
    """
    steps = [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector

    Wraps a ``timm`` backbone created with ``num_classes=0`` and
    ``global_pool="avg"``.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        """
        Args:
            model_name: Architecture name handed to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model`` as its second argument.
            trainable: Value written to ``requires_grad`` on every backbone parameter.
        """
        super().__init__()
        backbone = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        # Freeze or unfreeze the whole backbone in one pass.
        for param in backbone.parameters():
            param.requires_grad = trainable
        self.model = backbone

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged."""
        features = self.model(x)
        return features
class TextEncoder(nn.Module):
    """Encode tokenized text into one vector per sequence.

    The hidden state at position ``target_token_idx`` (0, the CLS token) is
    used as the sentence embedding.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        """
        Args:
            model_name: Accepted for interface compatibility but not consulted;
                the pretrained path always loads ``CFG.clinical_encoder_model``
                (Bio-ClinicalBERT).
            pretrained: When truthy, load the Bio-ClinicalBERT weights;
                otherwise build a ``DistilBertModel`` from a default
                ``DistilBertConfig``.
            trainable: Value written to ``requires_grad`` on every encoder parameter.
        """
        super().__init__()
        self.model = (
            AutoModel.from_pretrained(CFG.clinical_encoder_model)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        for param in self.model.parameters():
            param.requires_grad = trainable
        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at ``target_token_idx`` for each sequence."""
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]
# Project both image and text encodings into vectors of the same size
class ProjectionHead(nn.Module):
    """Map encoder features of size ``embedding_dim`` to ``projection_dim``.

    Computes ``LayerNorm(Dropout(Linear(GELU(p))) + p)`` where
    ``p = Linear(x)`` is the initial projection, i.e. a single residual block
    on top of the projection.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Args:
            embedding_dim: Size of the incoming feature vectors.
            projection_dim: Size of the produced embeddings.
            dropout: Drop probability applied before the residual sum.
        """
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation is stable.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and refine it with the residual block."""
        projected = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(refined + projected)
class CLIPModel(nn.Module):
    """Image and text encoders with projection heads, trained contrastively.

    ``forward`` returns the scalar training loss rather than embeddings.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Args:
            temperature: Divides the logits and multiplies the soft-target scores.
            image_embedding: Feature size fed to the image projection head.
            text_embedding: Feature size fed to the text projection head.
        """
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Compute the mean symmetric soft-target cross-entropy for ``batch``.

        Args:
            batch: Mapping providing ``"image"``, ``"input_ids"`` and
                ``"attention_mask"`` entries.

        Returns:
            Scalar tensor: the batch mean of the averaged image-side and
            text-side losses.
        """
        # Encoders run before either projection head, in this fixed order.
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Embeddings with the same dimension for both modalities.
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # Text-to-image similarity logits.
        logits = (txt_emb @ img_emb.T) / self.temperature

        # Soft targets from the average of the within-modality similarities.
        img_sim = img_emb @ img_emb.T
        txt_sim = txt_emb @ txt_emb.T
        mean_sim = (img_sim + txt_sim) / 2
        targets = F.softmax(mean_sim * self.temperature, dim=-1)

        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        per_sample = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return per_sample.mean()
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        preds: Logits; log-softmax is taken over the last dimension.
        targets: Target weights with the same shape as ``preds``.
        reduction: ``"none"`` returns one loss per row (summed over dim 1),
            ``"mean"`` returns their mean, ``"sum"`` returns their total.

    Returns:
        The per-row losses or the reduced scalar tensor.

    Raises:
        ValueError: ``reduction`` is not one of the supported values
            (previously this case silently returned ``None``).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(
        f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
    )
# INFERENCE CODE
def get_image_embeddings(image, model=None):
    """Embed one BGR image with the CLIP image tower.

    Args:
        image: Image array as returned by ``cv2.imread`` (BGR), or ``None``.
        model: Optional already-loaded ``CLIPModel``. When omitted, a fresh
            model is built and its weights are read from ``'weights.pt'``,
            matching the previous behaviour.

    Returns:
        L2-normalized image embeddings with a leading batch dimension of 1,
        or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        print("Image not found!")
        return None

    if model is None:
        model = CLIPModel().to(device)
        model.load_state_dict(torch.load('weights.pt', map_location=device))
    model.eval()
    # Follow the model's own placement so a caller-supplied model works on any device.
    model_device = next(model.parameters()).device

    # Same preprocessing as the dataset: BGR -> RGB, resize + normalize.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = get_transforms("valid")(image=image)['image']
    # The transform output is height x width x channels. Move channels first
    # with permute, as CLIPDataset does; reshape would keep the memory order
    # and scramble pixels across channels.
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float()

    with torch.no_grad():
        image_features = model.image_encoder(image_tensor.unsqueeze(0).to(model_device))
        image_embeddings = model.image_projection(image_features)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=-1)
    return image_embeddings


def predict_caption(image, model, text_embeddings, captions, n=1):
    """Return the ``n`` captions whose embeddings best match ``image``.

    Args:
        image: BGR image array, or ``None``.
        model: Loaded ``CLIPModel`` reused for the image embedding; passing
            ``None`` falls back to loading the weights from disk.
        text_embeddings: One embedding row per entry of ``captions``.
        captions: Indexable collection of candidate captions.
        n: Number of matches requested; capped at the number of embedding rows.

    Returns:
        List of the best-matching captions in descending similarity, or
        ``None`` when ``image`` is ``None``.
    """
    # Reuse the caller's model instead of rebuilding it for every prediction.
    image_embeddings = get_image_embeddings(image, model)
    if image_embeddings is None:
        return None

    # Cosine similarity: normalize both sides, then take the dot product.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(
        text_embeddings.to(image_embeddings.device), p=2, dim=-1
    )
    dot_similarity = image_embeddings_n @ text_embeddings_n.T

    # topk raises when k exceeds the number of candidates, so cap it.
    top_k = min(n, dot_similarity.shape[-1])
    _, indices = torch.topk(dot_similarity.squeeze(0), top_k)
    return [captions[idx] for idx in indices.cpu().tolist()]
def get_text_embeddings(valid_df):
    """Embed every caption of ``valid_df`` with the CLIP text tower.

    Captions are tokenized directly instead of going through the image
    dataset. That path decoded every image and, when an image was unreadable,
    substituted a neighbouring sample, which could leave the returned rows
    out of step with the caption list.

    Args:
        valid_df: Table exposing a ``"caption"`` column.

    Returns:
        Tuple ``(model, embeddings)``: the loaded ``CLIPModel`` in eval mode
        and a tensor with one projected text embedding per caption, in the
        order of ``valid_df["caption"]``.
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.clinical_encoder_model)
    captions = list(valid_df["caption"].values)
    # Tokenize the whole list at once so padding matches the dataset path.
    encoded = tokenizer(
        captions,
        padding=True,
        truncation=True,
        max_length=CFG.max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    model = CLIPModel().to(device)
    model.load_state_dict(torch.load("weights.pt", map_location=device))
    model.eval()

    valid_text_embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(captions), CFG.batch_size)):
            stop = start + CFG.batch_size
            text_features = model.text_encoder(
                input_ids=input_ids[start:stop].to(device),
                attention_mask=attention_mask[start:stop].to(device),
            )
            valid_text_embeddings.append(model.text_projection(text_features))
    return model, torch.cat(valid_text_embeddings)