# linear_code/surgvlp_linear_evaluation.py
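"""Linear evaluation of a pre-trained SurgVLP-style video-language model.

The frozen backbone extracts frame and text embeddings; a single linear layer
is trained on top of the frame features for phase classification and then
evaluated per test video (accuracy and F1 averaged video-wise).
"""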
import torch.nn as nn
import argparse
import torch
import clip
from PIL import Image
import sys
sys.path.append('../../../')
from codes.datasets import build_dataset
from codes.models import build_algorithm
from mmengine.config import Config
from transformers import AutoTokenizer
from baselines.utils import calc_accuracy, calc_f1
import torchmetrics
import numpy as np
from torch.utils.data import ConcatDataset
import torch.optim as optim
def process_text(text):
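    """Tokenise a caption (or list of captions) with the clinical BERT
    tokenizer and return GPU tensors ready for the text encoder.

    Note: the tokenizer path below is a hard-coded cluster path; point it to a
    local copy of the BioBERT checkpoint when running elsewhere.
    """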
tokenizer_clinical = AutoTokenizer.from_pretrained('/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000')
ixtoword = {v: k for k, v in tokenizer_clinical.get_vocab().items()}
    if isinstance(text, str):
text = [text]
processed_text_tensors = []
for t in text:
text_tensors = tokenizer_clinical(
t,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=77,
)
text_tensors["sent"] = [
ixtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
]
processed_text_tensors.append(text_tensors)
caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
attention_mask = torch.stack(
[x["attention_mask"] for x in processed_text_tensors]
)
token_type_ids = torch.stack(
[x["token_type_ids"] for x in processed_text_tensors]
)
if len(text) == 1:
        caption_ids = caption_ids.squeeze(0).cuda()
        attention_mask = attention_mask.squeeze(0).cuda()
        token_type_ids = token_type_ids.squeeze(0).cuda()
else:
caption_ids = caption_ids.squeeze().cuda()
attention_mask = attention_mask.squeeze().cuda()
token_type_ids = token_type_ids.squeeze().cuda()
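    # Rough caption lengths: counts the characters of each raw string that do
    # not start with "[" (not the number of tokens).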
cap_lens = []
for txt in text:
cap_lens.append(len([w for w in txt if not w.startswith("[")]))
return {
"input_ids": caption_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"cap_lens": cap_lens,
}
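# Example (illustrative caption): process_text(["grasper retracts the gallbladder"])
# returns a dict with 'input_ids', 'attention_mask', 'token_type_ids' (already on
# the GPU) and 'cap_lens'.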
def test(classifier, test_loaders, model, args):
    class_prompt = args.class_prompt
    model.eval()
    with open(class_prompt) as f:
        lines = f.readlines()
class_texts = [i.replace('\n', '') for i in lines]
class_texts = process_text(class_texts)
text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
text_features /= text_features.norm(dim=-1, keepdim=True)
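    # text_features are only needed if the commented-out zero-shot scoring in
    # the loop below is re-enabled; the linear classifier does not use them.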
total_acc = []
total_f1_phase = []
total_f1_phase_class = []
with torch.no_grad():
for test_loader in test_loaders:
probs_list = []
label_list = []
for i, data in enumerate(test_loader):
                frames = data['video'].cuda()  # (B, C, H, W): one frame per sample
                B, C, H, W = frames.shape
                frames = frames.view(-1, C, H, W)
                image_features = model(frames, None, mode='video')['img_emb']  # (B, D)
probs = classifier(image_features)
# probs = probs / probs.norm(dim=-1, keepdim=True)
# probs = probs @ text_features.to(dtype=torch.float32).T
                probs = probs.softmax(dim=-1)  # (B, num_classes)
labels = data['label'].cuda()
probs_list.append(probs)
label_list.append(labels)
            # Aggregate predictions for the whole video and compute metrics.
            probs_list = torch.cat(probs_list, 0)
            labels = torch.cat(label_list, 0)
acc = calc_accuracy(probs_list, labels)
print('accuracy: ', acc)
f1_class, f1_average = calc_f1(probs_list, labels)
print('f1 average: ', f1_average)
print('f1 classes: ', f1_class)
total_acc.append(acc)
total_f1_phase.append(f1_average)
print('f1 phase video-wise average ', np.mean(np.asarray(total_f1_phase)))
print('Acc video-wise average ', np.mean(np.asarray(total_acc)))
def linear_evaluation(
train_loader: torch.utils.data.DataLoader,
val_loader: torch.utils.data.DataLoader,
model: torch.nn.Module,
num_classes: int
) -> torch.nn.Module:
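    """Train a linear classifier on frozen backbone features and return it."""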
# Freeze the pre-trained model's parameters
for param in model.parameters():
param.requires_grad = False
    # Class prompts / text features are only used by the (commented-out)
    # prompt-based head further below; note that `args` is read from module scope.
    class_prompt = args.class_prompt
    with open(class_prompt) as f:
        lines = f.readlines()
    class_texts = [i.replace('\n', '') for i in lines]
    class_texts = process_text(class_texts)
    text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Linear probe on top of the frozen image features
    # (2048 matches the image-embedding dimension produced by the backbone).
    classifier = nn.Linear(2048, num_classes).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)  # only takes effect if scheduler.step() below is enabled
# Training loop
model.eval() # Ensure the model is in evaluation mode
for epoch in range(25):
for batch in train_loader:
inputs = batch['video'].cuda()
labels = batch['label'].cuda()
# Forward pass through the pre-trained model to get features
with torch.no_grad():
features = model(inputs, None, mode='video')['img_emb'] # (B*M*T, D)
features = features.to(dtype=torch.float32)
# Forward pass through the classifier
outputs = classifier(features)
# outputs_feat = outputs_feat / outputs_feat.norm(dim=-1, keepdim=True)
# outputs = outputs_feat @ text_features.T
            loss = criterion(outputs, labels)
            print(f'epoch {epoch} loss {loss.item():.4f}')
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
# scheduler.step()
# Validation can be added here if needed
# classifier = classifier.eval()
# test(classifier, test_loaders, model, args)
# classifier = classifier.train()
return classifier # Return the trained classifier
def get_args(description='CLIP'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--class_prompt', default='../class_prompt.txt', type=str, help='prompt for categories')
parser.add_argument('--dataset_config', default='./config.py', type=str, help='dataset config')
    parser.add_argument('--batch_size', default=1, type=int, help='batch size for training and evaluation')
    parser.add_argument('--num_class', default=12, type=int, help='number of classes for classification')
parser.add_argument('--checkpoint', default='', type=str, help='Checkpoint to load')
args = parser.parse_args()
return args, parser
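# Example invocation (paths are placeholders):
#   python surgvlp_linear_evaluation.py --dataset_config ./config.py \
#       --class_prompt ../class_prompt.txt --checkpoint <path/to/checkpoint.pth.tar>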
if __name__ == "__main__":
args, _ = get_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
configs = Config.fromfile(args.dataset_config)['config']
model = build_algorithm(configs.model_config).cuda()
    # Load pre-trained weights; earlier experiments hard-coded specific
    # checkpoint paths here, the checkpoint is now passed via --checkpoint.
state_dict = torch.load(args.checkpoint)['state_dict']
    # Strip the 'module.' prefix added by DataParallel and remap the legacy
    # module names ('visual.*', 'text_module.*') to the current backbone names.
    new_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            new_dict[k[7:].replace('visual.model.', 'backbone_img.model.').replace('text_module.model.', 'backbone_text.model.').replace('visual.global_embedder', 'backbone_img.global_embedder')] = v
    model.load_state_dict(new_dict, strict=True)
model.eval()
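    # Build training/validation datasets (concatenated across config entries)
    # and one dataset per test video so metrics can be averaged video-wise.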
train_datasets = [build_dataset(c) for c in configs.train_config]
train_dataset = ConcatDataset(train_datasets)
val_datasets = [build_dataset(c) for c in configs.val_config]
val_dataset = ConcatDataset(val_datasets)
test_datasets = [build_dataset(c) for c in configs.test_config]
# 40 videos --> 40 datasets
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=4
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=4
)
test_loaders = [torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
num_workers=0
) for test_dataset in test_datasets] # 40 dataloaders
print(args)
classifier = linear_evaluation(train_loader, val_loader, model, args.num_class)
test(classifier, test_loaders, model, args)