from PIL import Image
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from enum import auto, Enum
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from models.video_transformers import (
    GroupNormalize, GroupScale, GroupCenterCrop,
    Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode
from transformers import LlamaTokenizer, LlamaConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()


def get_prompt(conv):
    """Flatten a conversation into a single prompt string."""
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + ": " + message + conv.sep
        else:
            ret += role + ":"
    return ret


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


class Chat:
    def __init__(self, model, device='cuda:0'):
        self.device = device
        self.model = model
        stop_words_ids = [
            torch.tensor([835]).to(self.device),
            torch.tensor([2277, 29937]).to(self.device),
        ]  # '###' can be encoded in two different ways.
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self, text, conv):
        conv.messages.append([conv.roles[0], text + '\n'])
        return conv

    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1,
               top_p=0.9, repetition_penalty=1.0, length_penalty=1, temperature=1.0):
        conv.messages.append([conv.roles[1], None])
        embs = self.get_context_emb(conv, img_list)
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=self.stopping_criteria,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
        output_token = outputs[0]
        if output_token[0] == 0:  # the model might output an unknown token at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find a start (BOS) token at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
        output_text = output_text.split('###')[0]  # remove the stop sign '###'
        output_text = output_text.split('Assistant:')[-1].strip()
        conv.messages[-1][1] = output_text
        return output_text, output_token.cpu().numpy(), conv

    def get_index(self, num_frames, num_segments):
        """Pick num_segments frame indices spread uniformly over the video."""
        seg_size = float(num_frames - 1) / num_segments
        start = int(seg_size / 2)
        offsets = np.array([
            start + int(np.round(seg_size * idx)) for idx in range(num_segments)
        ])
        return offsets

    def load_video(self, video_path, num_segments=8, return_msg=False):
        vr = VideoReader(video_path, ctx=cpu(0))
        num_frames = len(vr)
        frame_indices = self.get_index(num_frames, num_segments)

        # transform
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        transform = T.Compose([
            GroupScale(int(224), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(224),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std)
        ])

        images_group = list()
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].asnumpy())
            images_group.append(img)
        torch_imgs_224 = transform(images_group)

        if return_msg:
            fps = float(vr.get_avg_fps())
            sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
            # " " should be added in the start and end
            msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
            return torch_imgs_224, msg
        else:
            return torch_imgs_224

    def upload_video(self, image, conv, img_list, num_segments):
        if isinstance(image, str):  # a video path
            vid_chat, msg = self.load_video(image, num_segments=num_segments, return_msg=True)
            TC, H, W = vid_chat.shape
            image = vid_chat.reshape(1, TC // 3, 3, H, W).to(self.device)
        else:
            raise NotImplementedError
        print("Input video shape:", vid_chat.shape)
        image_emb, _ = self.model.encode_img(image)
        img_list.append(image_emb)
        # '<VideoHere>' is the visual placeholder that get_context_emb() splits on.
        conv.messages.append([
            conv.roles[0],
            f"<Video><VideoHere></Video> {msg}\n"
        ])
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg, img_list, conv

    def upload_img(self, image, conv, img_list):
        img = image  # Image.open(image).convert('RGB')
        transform = T.Compose([
            T.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ])
        img = transform(img).unsqueeze(0).unsqueeze(0).to(self.device)
        image_emb, _ = self.model.encode_img(img)
        img_list.append(image_emb)
        # '<ImageHere>' is the visual placeholder that get_context_emb() splits on.
        conv.messages.append([
            conv.roles[0],
            "<Image><ImageHere></Image>\n"
        ])
        msg = "Received."
        # self.conv.append_message(self.conv.roles[1], msg)
        return msg, img_list, conv

    def get_context_emb(self, conv, img_list):
        prompt = get_prompt(conv)
        # print(prompt)
        if '<VideoHere>' in prompt:
            prompt_segs = prompt.split('<VideoHere>')
        else:
            prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of visual placeholders and videos."
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors="pt",
                add_special_tokens=(i == 0)  # only add the BOS token to the first segment
            ).to(self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        # Interleave text-segment embeddings with the visual embeddings.
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
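

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not invoked by the library code).
# Assumptions: `model` must be a loaded VideoChat-style model exposing
# `llama_model`, `llama_tokenizer`, and `encode_img`; it is not constructed in
# this module, so the `model = ...` line below is a placeholder to replace.
# The conversation object is built here with SimpleNamespace as a stand-in for
# the repo's own Conversation class; only the `system`, `roles`, `messages`,
# and `sep` fields used by get_prompt() are required. The video path is an
# example value.
if __name__ == "__main__":
    from types import SimpleNamespace

    model = ...  # replace with a loaded VideoChat-style model
    chat = Chat(model, device='cuda:0')

    conv = SimpleNamespace(
        system="A chat between a curious human and an AI assistant.",
        roles=("Human", "Assistant"),
        messages=[],
        sep="###",
    )
    img_list = []

    # Encode a video and register its placeholder in the conversation.
    _, img_list, conv = chat.upload_video("example.mp4", conv, img_list, num_segments=8)

    # Ask a question and generate an answer conditioned on the video embedding.
    conv = chat.ask("What is happening in the video?", conv)
    answer_text, _, conv = chat.answer(conv, img_list, max_new_tokens=200)
    print(answer_text)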