import torch import torch.nn as nn from huggingface_hub import snapshot_download from peft import get_peft_model, LoraConfig from transformers import LlamaForCausalLM, LlamaTokenizer def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self class LeoAgentLLM(nn.Module): def __init__(self, cfg): super().__init__() # LLM if cfg.launch_mode == 'hf': llm_cfg_path = snapshot_download(cfg.model.llm.hf_cfg_path) else: llm_cfg_path = cfg.model.llm.local_cfg_path self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_cfg_path, use_fast=False, truncation_side=cfg.model.llm.truncation_side) self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) self.llm_tokenizer.add_special_tokens({'bos_token': ''}) self.llm_tokenizer.add_special_tokens({'eos_token': ''}) self.llm_tokenizer.add_special_tokens({'unk_token': ''}) self.llm_model = LlamaForCausalLM.from_pretrained(llm_cfg_path, torch_dtype=torch.float16) self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) for param in self.llm_model.parameters(): param.requires_grad = False self.llm_model.eval() self.llm_model.train = disabled_train # LoRA-based LLM fine-tuning if cfg.model.llm.lora.flag: lora_config = LoraConfig( r=cfg.model.llm.lora.rank, lora_alpha=cfg.model.llm.lora.alpha, target_modules=cfg.model.llm.lora.target_modules, lora_dropout=cfg.model.llm.lora.dropout, bias='none', modules_to_save=[], ) self.llm_model = get_peft_model(self.llm_model, peft_config=lora_config) self.max_context_len = cfg.model.llm.max_context_len @property def device(self): return list(self.parameters())[0].device def build_right_justified_sequence(self, data_dict): """ Concat six sequences: `prompt_before_obj`, `prompt_middle_1`, `img_tokens`, `prompt_middle_2`, `obj_tokens`, `prompt_after_obj`. Return right justified sequence for causal LM: , , , , . """ bs = len(data_dict['prompt_before_obj']) self.llm_tokenizer.padding_side = 'left' text_input_tokens_pre = self.llm_tokenizer( data_dict['prompt_before_obj'], return_tensors='pt', padding='longest' ).to(self.device) # [PAD, BOS, tokens], (B, T1) text_input_tokens_mid1 = self.llm_tokenizer( data_dict['prompt_middle_1'], return_tensors='pt', padding='longest' ).to(self.device) img_tokens = data_dict['img_tokens'].to(self.device) img_masks = data_dict['img_masks'].to(self.device) img_masks = img_masks.reshape(-1, 1).repeat(1, img_tokens.size(1)) text_input_tokens_mid2 = self.llm_tokenizer( data_dict['prompt_middle_2'], return_tensors='pt', padding='longest' ).to(self.device) obj_tokens = data_dict['obj_tokens'].to(self.device) obj_masks = data_dict['obj_masks'].to(self.device) self.llm_tokenizer.padding_side = 'right' # no need to be 'left', as padding tokens will be shifted self.llm_tokenizer.truncation_side = 'left' # truncate history text_input_tokens_post = self.llm_tokenizer( data_dict['prompt_after_obj'], return_tensors='pt', padding='longest', truncation=True, max_length=self.max_context_len, ).to(self.device) # [BOS, tokens, PAD], (B, T3) # hardcode, remove bos, make "tokenize subseq and concat" equivalent to "tokenize the whole seq" assert text_input_tokens_mid1.attention_mask.all() and text_input_tokens_mid2.attention_mask.all(), \ "prompt_middle should be the same and thus no padding" text_input_tokens_mid1.input_ids = text_input_tokens_mid1.input_ids[:, 1:] text_input_tokens_mid1.attention_mask = text_input_tokens_mid1.attention_mask[:, 1:] for i in range(bs): if not img_masks[i].any(): # no image input, also mask the text prompt for image tokens text_input_tokens_mid1.attention_mask[i].fill_(0) text_input_tokens_mid2.input_ids[:, 0] = 869 # 1 (bos) -> 869 (▁.) text_input_tokens_post.input_ids[:, 0] = 869 # 1 (bos) -> 869 (▁.) inputs_embeds_pre = self.llm_model.get_input_embeddings()(text_input_tokens_pre.input_ids) inputs_embeds_mid1 = self.llm_model.get_input_embeddings()(text_input_tokens_mid1.input_ids) inputs_embeds_mid2 = self.llm_model.get_input_embeddings()(text_input_tokens_mid2.input_ids) inputs_embeds_post = self.llm_model.get_input_embeddings()(text_input_tokens_post.input_ids) # since img_tokens, prompt_mid, obj_tokens are fixed length without padding, we concat them first inputs_embeds_mid = torch.cat([inputs_embeds_mid1, img_tokens, inputs_embeds_mid2, obj_tokens], dim=1) attn_mask_mid = torch.cat([ text_input_tokens_mid1.attention_mask, img_masks, text_input_tokens_mid2.attention_mask, obj_masks ], dim=1) post_pad_length = torch.logical_not(text_input_tokens_post.attention_mask).sum(-1) bs, l1, hidden_dim = inputs_embeds_pre.shape _, l2, _ = inputs_embeds_mid.shape _, l3, _ = inputs_embeds_post.shape inputs_embeds = torch.zeros( bs, l1+l2+l3, hidden_dim ).type(inputs_embeds_pre.dtype).to(self.device) attention_mask = torch.zeros( bs, l1+l2+l3 ).type(obj_masks.dtype).to(self.device) # assign by chunks for i in range(bs): post_pad_len = post_pad_length[i] if post_pad_len > 0: inputs_embeds[i, :post_pad_len] = inputs_embeds_post[i, -post_pad_len:] attention_mask[i, :post_pad_len] = 0 inputs_embeds[i, post_pad_len+l1+l2:] = inputs_embeds_post[i, :-post_pad_len] attention_mask[i, post_pad_len+l1+l2:] = 1 else: # no padding inputs_embeds[i, -l3:] = inputs_embeds_post[i] attention_mask[i, -l3:] = 1 inputs_embeds[i, post_pad_len: post_pad_len+l1] = inputs_embeds_pre[i] attention_mask[i, post_pad_len: post_pad_len+l1] = text_input_tokens_pre.attention_mask[i] inputs_embeds[i, post_pad_len+l1: post_pad_len+l1+l2] = inputs_embeds_mid[i] attention_mask[i, post_pad_len+l1: post_pad_len+l1+l2] = attn_mask_mid[i] return inputs_embeds, attention_mask @torch.no_grad() def generate( self, data_dict, use_nucleus_sampling=False, num_beams=5, max_length=256, min_length=1, repetition_penalty=3.0, length_penalty=1, num_captions=1, temperature=1, ): assert 'img_tokens' in data_dict and 'obj_tokens' in data_dict, "Visual features should have been processed offline." inputs_embeds, attention_mask = self.build_right_justified_sequence(data_dict=data_dict) bs = inputs_embeds.shape[0] # give bos token as condition bos_tokens = self.llm_tokenizer( [self.llm_tokenizer.bos_token] * bs, return_tensors='pt', ).to(self.device) bos_tokens_ids = bos_tokens.input_ids[:, 0:1] # (B, 1) bos_tokens_attn = bos_tokens.attention_mask[:, 0:1] # (B, 1) # prepare a `bos_token` bos_embeds = self.llm_model.get_input_embeddings()(bos_tokens_ids) # (B, 1, D) inputs_embeds = torch.cat([inputs_embeds, bos_embeds], dim=1) # (B, T1+O+T2+1, D) attention_mask = torch.cat([attention_mask, bos_tokens_attn], dim=1) # (B, T1+O+T2+1) outputs = self.llm_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=use_nucleus_sampling, temperature=temperature, num_beams=num_beams, max_length=max_length, min_length=min_length, repetition_penalty=repetition_penalty, length_penalty=length_penalty, num_return_sequences=num_captions, ) outputs[outputs == 0] = 2 # convert output id 0 (unk_token) to 2 (eos_token) output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) output_text = [text.strip() for text in output_text] return output_text