| import torch.nn as nn |
| import argparse |
| import torch |
| import clip |
| from PIL import Image |
| import sys |
| sys.path.append('../../../') |
| from codes.datasets import build_dataset |
| from codes.models import build_algorithm |
| from mmengine.config import Config |
| from transformers import AutoTokenizer |
| from baselines.utils import calc_accuracy, calc_f1 |
| import torchmetrics |
| import numpy as np |
| from torch.utils.data import ConcatDataset |
| import torch.optim as optim |
|
|
def process_text(text):
    """Tokenize one caption (or a list of captions) with the clinical BioBERT tokenizer.

    Args:
        text: A single string or a list of strings.

    Returns:
        dict with CUDA tensors ``input_ids``, ``attention_mask`` and
        ``token_type_ids`` (batch dim squeezed away when a single caption is
        given) plus ``cap_lens``, a per-caption length list.
    """
    # Load the tokenizer once and cache it on the function object so repeated
    # calls do not re-read the pretrained files from disk every time.
    if not hasattr(process_text, "_tokenizer"):
        process_text._tokenizer = AutoTokenizer.from_pretrained(
            '/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000')
    tokenizer_clinical = process_text._tokenizer
    # Inverse vocabulary: token id -> token string.
    ixtoword = {v: k for k, v in tokenizer_clinical.get_vocab().items()}

    if isinstance(text, str):
        text = [text]

    processed_text_tensors = []
    for t in text:
        # Pad/truncate every caption to CLIP's 77-token context length.
        text_tensors = tokenizer_clinical(
            t,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=77,
        )
        # Human-readable token strings, useful for debugging.
        text_tensors["sent"] = [
            ixtoword[ix] for ix in text_tensors["input_ids"][0].tolist()
        ]
        processed_text_tensors.append(text_tensors)

    caption_ids = torch.stack([x["input_ids"] for x in processed_text_tensors])
    attention_mask = torch.stack(
        [x["attention_mask"] for x in processed_text_tensors]
    )
    token_type_ids = torch.stack(
        [x["token_type_ids"] for x in processed_text_tensors]
    )

    # For a single caption drop only the stacking dim; for a batch, squeeze
    # the singleton per-item dim added by return_tensors="pt".
    if len(text) == 1:
        caption_ids = caption_ids.squeeze(0).cuda()
        attention_mask = attention_mask.squeeze(0).cuda()
        token_type_ids = token_type_ids.squeeze(0).cuda()
    else:
        caption_ids = caption_ids.squeeze().cuda()
        attention_mask = attention_mask.squeeze().cuda()
        token_type_ids = token_type_ids.squeeze().cuda()

    # NOTE(review): this iterates the CHARACTERS of the raw caption, so
    # cap_lens counts characters other than "[" — presumably the token list
    # in text_tensors["sent"] was intended. Kept as-is because downstream
    # code may rely on the current values; confirm before changing.
    cap_lens = []
    for txt in text:
        cap_lens.append(len([w for w in txt if not w.startswith("[")]))

    return {
        "input_ids": caption_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "cap_lens": cap_lens,
    }
|
|
def test(classifier, test_loader, model, args):
    """Evaluate the linear probe on each per-video test loader.

    Args:
        classifier: linear head mapping frozen image embeddings to class logits.
        test_loader: iterable of DataLoaders, one per test video (despite the
            singular name, the call site passes the full list of loaders).
        model: frozen backbone exposing mode='video' / mode='text' forwards.
        args: parsed CLI arguments; only ``class_prompt`` is read here.
    """
    class_prompt = args.class_prompt

    model.eval()

    # ``with`` closes the file automatically; no explicit close needed.
    with open(class_prompt) as f:
        lines = f.readlines()

    class_texts = [i.replace('\n', '') for i in lines]
    class_texts = process_text(class_texts)
    # NOTE(review): these normalized text features are computed but never used
    # below (the trained linear classifier is used instead of zero-shot
    # matching); kept to preserve the original flow.
    text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    total_acc = []
    total_f1_phase = []

    with torch.no_grad():
        # BUG FIX: the original iterated the module-level global
        # ``test_loaders``, silently ignoring the argument passed in.
        # Iterate the parameter instead.
        for loader in test_loader:
            probs_list = []
            label_list = []

            for i, data in enumerate(loader):
                frames = data['video'].cuda()
                B, C, H, W = frames.shape
                frames = frames.view(-1, C, H, W)
                image_features = model(frames, None, mode='video')['img_emb']

                probs = classifier(image_features)
                probs = probs.softmax(dim=-1)
                labels = data['label'].cuda()

                probs_list.append(probs)
                label_list.append(labels)

            # Concatenate all frame-level predictions of this video.
            probs_list = torch.cat(probs_list, 0)
            labels = torch.cat(label_list, 0)

            acc = calc_accuracy(probs_list, labels)
            print('accuracy: ', acc)
            f1_class, f1_average = calc_f1(probs_list, labels)
            print('f1 average: ', f1_average)
            print('f1 classes: ', f1_class)

            total_acc.append(acc)
            total_f1_phase.append(f1_average)

    # Averages across videos (video-wise, not frame-wise).
    print('f1 phase video-wise average ', np.mean(np.asarray(total_f1_phase)))
    print('Acc video-wise average ', np.mean(np.asarray(total_acc)))
|
|
|
|
|
|
def linear_evaluation(
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    num_classes: int
) -> torch.nn.Module:
    """Train a linear probe on top of frozen backbone image features.

    Args:
        train_loader: yields dicts with 'video' and 'label' tensors.
        val_loader: currently unused — accepted for interface compatibility.
        model: backbone exposing mode='video' / mode='text' forwards; frozen here.
        num_classes: output dimension of the linear head.

    Returns:
        The trained ``nn.Linear`` classifier (on CUDA).
    """
    # Freeze the backbone; only the linear head is optimized.
    for param in model.parameters():
        param.requires_grad = False

    # NOTE(review): this reads the module-level ``args`` (not a parameter),
    # and the resulting ``text_features`` are never used below. Kept to
    # preserve behavior (including failing fast if the prompt file is
    # missing); consider removing or passing the path explicitly.
    class_prompt = args.class_prompt
    with open(class_prompt) as f:
        lines = f.readlines()

    class_texts = [i.replace('\n', '') for i in lines]
    class_texts = process_text(class_texts)
    text_features = model(None, class_texts, mode='text')['text_emb'].cuda()
    text_features /= text_features.norm(dim=-1, keepdim=True).to(dtype=torch.float32)

    # Linear probe over 2048-d image embeddings.
    classifier = nn.Linear(2048, num_classes).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)

    model.eval()
    for epoch in range(25):
        for batch in train_loader:
            inputs = batch['video'].cuda()
            labels = batch['label'].cuda()

            # Features come from the frozen backbone; no grads needed.
            with torch.no_grad():
                features = model(inputs, None, mode='video')['img_emb']

            features = features.to(dtype=torch.float32)
            outputs = classifier(features)

            loss = criterion(outputs, labels)
            print(loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # BUG FIX: the scheduler was constructed but never stepped, so the
        # cosine annealing had no effect; advance it once per epoch.
        scheduler.step()

    return classifier
|
|
def get_args(description='CLIP', argv=None):
    """Build and parse the command-line arguments for linear evaluation.

    Args:
        description: help text shown by ``-h``.
        argv: optional explicit argument list; ``None`` (the default) keeps
            the original behavior of reading ``sys.argv``. Passing a list
            makes the function testable and reusable programmatically.

    Returns:
        Tuple ``(args, parser)`` — the parsed namespace and the parser itself.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--class_prompt', default='../class_prompt.txt', type=str, help='prompt for categories')
    parser.add_argument('--dataset_config', default='./config.py', type=str, help='dataset config')
    parser.add_argument('--batch_size', default=1, type=int, help='batch for testing')
    parser.add_argument('--num_class', default=12, type=int, help='class for classification')
    parser.add_argument('--checkpoint', default='', type=str, help='Checkpoint to load')
    args = parser.parse_args(argv)
    return args, parser
|
|
# NOTE(review): imported but not referenced in this chunk — presumably a
# leftover from a distributed variant; verify before removing.
import torch.distributed as dist

if __name__ == "__main__":
    args, _ = get_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Dataset/model configuration lives in an mmengine-style config file
    # whose top-level 'config' entry holds everything used below.
    configs = Config.fromfile(args.dataset_config)['config']

    # Build the backbone (image + text encoders) and move it to GPU.
    model = build_algorithm(configs.model_config).cuda()

    # NOTE(review): torch.load without map_location assumes the checkpoint's
    # original device is available — confirm when loading on CPU-only hosts.
    state_dict = torch.load(args.checkpoint)['state_dict']

    # Remap DataParallel/DistributedDataParallel checkpoint keys:
    # strip the 'module.' prefix (7 chars) and rename the pretraining
    # submodules to the names this model uses (visual.* -> backbone_img.*,
    # text_module.* -> backbone_text.*). Keys without 'module.' are dropped.
    new_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            new_dict[k[7:].replace('visual.model.', 'backbone_img.model.').replace('text_module.model.', 'backbone_text.model.').replace('visual.global_embedder','backbone_img.global_embedder')] = v

    # strict=True: fail loudly if any expected key is missing or unexpected.
    a, b = model.load_state_dict(new_dict, strict=True)

    model.eval()

    # Training/validation sets are concatenated into single datasets; the
    # test sets stay separate so metrics can be reported per video.
    train_datasets = [build_dataset(c) for c in configs.train_config]
    train_dataset = ConcatDataset(train_datasets)

    val_datasets = [build_dataset(c) for c in configs.val_config]
    val_dataset = ConcatDataset(val_datasets)

    test_datasets = [build_dataset(c) for c in configs.test_config]

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=4
    )

    # One loader per test video; num_workers=0 keeps per-video iteration simple.
    test_loaders = [torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0
    ) for test_dataset in test_datasets]
    print(args)

    # Train the linear probe on frozen features, then evaluate it per video.
    classifier = linear_evaluation(train_loader, val_loader, model, args.num_class)

    test(classifier, test_loaders, model, args)
|
|
|
|