# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import json

import torch.nn as nn

from model.modules.vit import load_encoder
from model.modules.roberta import load_decoder
from transformers import RobertaTokenizer, RobertaConfig


class Prismer(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Map each expert modality to its number of input channels:
        # RGB and surface normals are 3-channel, depth and edge maps are
        # 1-channel, and segmentation / detection label maps are embedded
        # into 64 channels.
        self.experts = {'rgb': 3}
        for exp in config['experts']:
            if exp in ['depth', 'edge']:
                self.experts[exp] = 1
            elif exp in ['normal']:
                self.experts[exp] = 3
            elif 'seg' in exp:
                self.experts['seg'] = 64
            elif exp in ['obj_detection', 'ocr_detection']:
                self.experts[exp] = 64

        with open('configs/prismer.json', 'r') as f:
            prismer_config = json.load(f)[config['prismer_model']]
        roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])

        self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
        self.expert_encoder = load_encoder(prismer_config['vit_model'], experts=self.experts,
                                           image_resolution=config['image_resolution'])
        self.text_decoder = load_decoder(prismer_config['roberta_model']['model_name'], config=roberta_config)

        self.prepare_to_train(config['freeze'])
        self.ignored_modules = self.get_ignored_modules(config['freeze'])

    def prepare_to_train(self, mode='none'):
        """Set requires_grad per parameter according to the freeze mode.

        Each pre-trained layer in both backbones is wrapped together with
        newly added components (adaptors, plus the sub-modules at index 1
        in the language decoder); only the pre-trained weights are frozen,
        while the added components always remain trainable.
        """
        for name, params in self.named_parameters():
            if mode == 'freeze_lang':
                # Freeze the pre-trained RoBERTa layers; keep adaptors and
                # the added sub-modules ('1.self', '1.output') trainable.
                if 'encoder.layer' in name and all(key not in name for key in ['1.self', '1.output', 'adaptor']):
                    params.requires_grad = False
                else:
                    params.requires_grad = True
            elif mode == 'freeze_vision':
                # Freeze the pre-trained ViT residual blocks; keep adaptors trainable.
                if 'transformer.resblocks' in name and 'adaptor' not in name:
                    params.requires_grad = False
                else:
                    params.requires_grad = True
            elif mode == 'freeze_lang_vision':
                # Freeze both backbones; only adaptors and added sub-modules train.
                if 'encoder.layer' in name and all(key not in name for key in ['1.self', '1.output', 'adaptor']):
                    params.requires_grad = False
                elif 'transformer.resblocks' in name and 'adaptor' not in name:
                    params.requires_grad = False
                else:
                    params.requires_grad = True
            else:
                params.requires_grad = True

    def get_ignored_modules(self, mode='none'):
        """Collect the frozen sub-modules (index 0 of each wrapped layer) so
        they can be excluded from distributed training wrappers, e.g. via
        FSDP's ignored_modules argument. Returns None when nothing is frozen.
        """
        ignored_modules = []
        if mode == 'freeze_lang':
            for l in range(len(self.text_decoder.roberta.encoder.layer)):
                ignored_modules += [
                    self.text_decoder.roberta.encoder.layer[l][0].attention,
                    self.text_decoder.roberta.encoder.layer[l][0].intermediate,
                    self.text_decoder.roberta.encoder.layer[l][0].output,
                ]
        elif mode == 'freeze_vision':
            for l in range(len(self.expert_encoder.transformer.resblocks)):
                ignored_modules += [
                    self.expert_encoder.transformer.resblocks[l][0].attn,
                    self.expert_encoder.transformer.resblocks[l][0].mlp,
                    self.expert_encoder.transformer.resblocks[l][0].ln_1,
                    self.expert_encoder.transformer.resblocks[l][0].ln_2,
                ]
        elif mode == 'freeze_lang_vision':
            for l in range(len(self.text_decoder.roberta.encoder.layer)):
                ignored_modules += [
                    self.text_decoder.roberta.encoder.layer[l][0].attention,
                    self.text_decoder.roberta.encoder.layer[l][0].intermediate,
                    self.text_decoder.roberta.encoder.layer[l][0].output,
                ]
            for l in range(len(self.expert_encoder.transformer.resblocks)):
                ignored_modules += [
                    self.expert_encoder.transformer.resblocks[l][0].attn,
                    self.expert_encoder.transformer.resblocks[l][0].mlp,
                    self.expert_encoder.transformer.resblocks[l][0].ln_1,
                    self.expert_encoder.transformer.resblocks[l][0].ln_2,
                ]
        else:
            ignored_modules = None
        return ignored_modules
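

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how Prismer might be instantiated.
# The config keys mirror those read in __init__ ('experts', 'prismer_model',
# 'image_resolution', 'freeze'); the concrete values below are assumptions
# for illustration and may not match the keys shipped in configs/prismer.json.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {
        'experts': ['depth', 'normal', 'seg_coco'],  # assumed expert names
        'prismer_model': 'prismer_base',             # assumed key in configs/prismer.json
        'image_resolution': 224,                     # assumed input resolution
        'freeze': 'freeze_lang_vision',              # train adaptors only
    }
    model = Prismer(config)

    # With 'freeze_lang_vision', only the adaptors and the added decoder
    # sub-modules should remain trainable.
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_total = sum(p.numel() for p in model.parameters())
    print(f'trainable parameters: {n_trainable:,} / {n_total:,}')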