"""Helpers that turn chat turns and images into ALLaVA model inputs.

Provides text moderation, image-placeholder insertion, prompt tokenisation
(with image sentinels), image preprocessing, and a small text streamer.
"""
import os
import warnings
from copy import deepcopy
from queue import Queue
from typing import List

import requests
import torch
from PIL import Image

# Sentinel id written into input_ids wherever an image should be spliced in.
IMAGE_TOKEN_INDEX = -200
# Textual placeholder marking an image position inside a prompt string;
# tokenizer_image_token replaces every occurrence with IMAGE_TOKEN_INDEX.
# NOTE(review): '<image>', '<s>', '</s>' follow the LLaVA convention; confirm
# they match the training prompts and the tokenizer's special tokens.
DEFAULT_IMAGE_TOKEN = '<image>'
# Substrings stripped from chat text by input_moderation so users cannot inject
# image placeholders (or, presumably, the tokenizer's BOS/EOS markers).
blacklist = [DEFAULT_IMAGE_TOKEN, '<s>', '</s>']
max_num_images = 3  # phi has a context length limit of 2048 and each image occupies 576 tokens.
# Seconds to wait on the network when an image is given as a URL.
_IMAGE_DOWNLOAD_TIMEOUT = 10


def input_moderation(texts: list[list[str]]):
    """Remove every ``blacklist`` substring from each chat turn, in place.

    Args:
        texts: list of ``[user_text, assistant_text]`` pairs; the assistant
            side may be ``None`` (the turn still awaiting a reply).

    Returns:
        The same ``texts`` object, after in-place modification.
    """
    for text_pair in texts:
        for banned in blacklist:
            text_pair[0] = text_pair[0].replace(banned, '')
            if text_pair[1] is not None:
                text_pair[1] = text_pair[1].replace(banned, '')
    return texts


def insert_image_placeholder(t, num_images, placeholder=DEFAULT_IMAGE_TOKEN, sep='\n'):
    """Return ``t`` prefixed with ``num_images`` copies of ``placeholder + sep``."""
    return f"{placeholder}{sep}" * num_images + t


def get_conv(texts):
    """Flatten ``[user, assistant]`` pairs into alternating human/gpt messages.

    The assistant value of the last pair is usually ``None`` (the reply still
    to be generated) and is passed through unchanged.
    """
    ret = []
    for conv in texts:
        ret.append({'from': 'human', 'value': conv[0]})
        ret.append({'from': 'gpt', 'value': conv[1]})  # None for the pending turn
    return ret


# copied from llava
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenise ``prompt``, replacing each DEFAULT_IMAGE_TOKEN with ``image_token_index``.

    Args:
        prompt: text that may contain DEFAULT_IMAGE_TOKEN placeholders.
        tokenizer: callable returning an object with ``input_ids`` and
            exposing ``bos_token_id``.
        image_token_index: id inserted at each placeholder position.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a LongTensor.

    Raises:
        ValueError: for any other ``return_tensors`` value.
    """
    prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids
                     for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(X, sep):
        # Interleave ``sep`` between consecutive chunks (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # If the first chunk starts with BOS, emit it once and strip it from every chunk.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def preprocess(tokenizer, data: list, return_tensors='pt'):
    """Tokenise a conversation with the ALLaVA prompt template.

    Args:
        tokenizer: tokenizer used for every turn.
        data: alternating messages, e.g.
            ``[{'from': 'human', 'value': ...}, {'from': 'gpt', 'value': ...}]``.
            Only the position (even = human, odd = gpt) is used, not ``'from'``.
        return_tensors: forwarded to tokenizer_image_token; ``'pt'`` is what
            the ``torch.cat`` concatenation in preprocess_allava expects.

    Raises:
        ValueError: if ``data`` is not a list.
    """
    if not isinstance(data, list):
        raise ValueError('must be a list')
    # The prompt template is per model (tokenizer); ALLaVA's is used here.
    return preprocess_allava(tokenizer, data, return_tensors=return_tensors)


def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list:
    """Tokenise ``convs`` with the Vicuna-v1 ``USER:`` / ``ASSISTANT:`` template.

    NOTE(review): despite being a module-level function this takes ``self``;
    it expects an object exposing ``tokenizer``, ``maxlen`` and a
    ``tokenizer_image_token(prompt=..., return_tensors=...)`` method.
    ``preprocess`` in this file dispatches to preprocess_allava instead.
    """
    input_ids = None
    for ind, conv in enumerate(convs):
        if ind % 2 == 0:  # human turn
            h = f"USER: {conv['value'].strip()} "
            cur_input_ids = self.tokenizer_image_token(prompt=h, return_tensors=return_tensors)
            input_ids = cur_input_ids if input_ids is None else torch.cat([input_ids, cur_input_ids])
        else:  # gpt turn; a None value leaves an open "ASSISTANT:" prompt
            g = conv['value']
            text = "ASSISTANT:" if g is None else f"ASSISTANT: {g}"
            cur_input_ids = self.tokenizer(text, add_special_tokens=False, max_length=self.maxlen,
                                           truncation=True, return_tensors='pt').input_ids[0]
            input_ids = torch.cat([input_ids, cur_input_ids])
    return input_ids


def preprocess_allava(tokenizer, convs: list, return_tensors) -> torch.Tensor:
    """Tokenise ``convs`` with the ``[INST] ... [/INST]`` template.

    Even-indexed entries are user turns (image placeholders become
    IMAGE_TOKEN_INDEX); odd-indexed entries are assistant replies, each
    terminated by ``tokenizer.eos_token``. An assistant value of ``None``
    contributes nothing, leaving the prompt open for generation.

    Returns:
        The concatenated ids (a 1-D tensor with ``return_tensors='pt'``), or
        ``None`` when ``convs`` is empty.
    """
    input_ids = None
    for ind, conv in enumerate(convs):
        if ind % 2 == 0:  # human turn
            h = conv['value'].strip()
            h = f"[INST] {h} [/INST] "
            cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)
            if input_ids is None:
                input_ids = cur_input_ids
            else:
                input_ids = torch.cat([input_ids, cur_input_ids])
        else:  # gpt turn
            g = conv['value']
            if g is not None:
                cur_input_ids = tokenizer(f"{g}{tokenizer.eos_token}", add_special_tokens=False,
                                          truncation=True, return_tensors='pt').input_ids[0]
                input_ids = torch.cat([input_ids, cur_input_ids])
    return input_ids


def _expand2square(pil_img, background_color):
    """Pad ``pil_img`` (converted to RGB) to a square canvas, centring the original."""
    if pil_img.mode != 'RGB':
        pil_img = pil_img.convert('RGB')
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result


# copied from llava
def get_image_tensors(processor, images, device):
    """Preprocess each image into a tensor on ``device``.

    Args:
        processor: image processor exposing ``crop_size`` (dict with
            ``'height'``/``'width'``), ``image_mean`` and ``preprocess``.
        images: iterable of file paths, PIL images, or ``None``. ``None`` is a
            placeholder and yields an all-zero ``(3, height, width)`` tensor.
        device: target device for every tensor.

    Raises:
        TypeError: for any other entry type.
    """
    list_image_tensors = []
    crop_size = processor.crop_size
    for fp in images:
        if fp is None:  # placeholder when no image is supplied
            list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
            continue
        if isinstance(fp, str):
            with Image.open(fp) as opened:  # close the file once decoded
                image = opened.convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp  # already an image
        else:
            raise TypeError(f'Unsupported type {type(fp)}')
        # Pad-to-square with the processor's mean colour (scaled to 0-255), the
        # preprocessing used in training, before the processor's own resize/crop.
        background = tuple(int(x * 255) for x in processor.image_mean)
        image = _expand2square(image, background)
        image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        list_image_tensors.append(image.to(device))
    return list_image_tensors


def _load_images(images):
    """Normalise ``images`` into a non-empty list of PIL images, or ``[None]``.

    Accepts ``None``, a single path / URL / PIL image, or a list of them.
    Entries that cannot be loaded are skipped with a warning (best effort).
    """
    if images is None:
        return [None]
    if isinstance(images, (str, Image.Image)):
        images = [images]
    valid_images = []
    for img in images:
        if img is None:
            continue
        if isinstance(img, Image.Image):
            valid_images.append(img)
            continue
        try:
            if os.path.exists(img):  # local file
                with Image.open(img) as opened:
                    valid_images.append(opened.convert('RGB'))
            else:  # otherwise it must be a URL
                with requests.get(img, stream=True, timeout=_IMAGE_DOWNLOAD_TIMEOUT) as resp:
                    resp.raise_for_status()
                    with Image.open(resp.raw) as opened:
                        # convert() forces decoding before the response closes.
                        valid_images.append(opened.convert('RGB'))
        except Exception as err:  # deliberate best effort: skip unreadable entries
            warnings.warn(f'Skipping image {img!r}: {err}')
    return valid_images or [None]


def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    """Build ``(input_ids, image_tensors, history)`` for one ALLaVA generation step.

    Args:
        tokenizer: tokenizer passed to preprocess.
        processor: image processor passed to get_image_tensors.
        texts: a user string, or a non-empty list of ``[user, assistant]``
            pairs (assistant ``None`` for the turn to be answered).
        images: ``None``, a path / URL / PIL image, or a list of them; entries
            that fail to load are skipped with a warning.
        history: optional earlier ``[user, assistant]`` pairs prepended to ``texts``.
        return_history: when true, also return the moderated conversation
            (without image placeholders) for use as the next ``history``.
        device: device for both returned tensors.

    Returns:
        ``input_ids`` of shape ``(1, seq_len)``, the stacked image tensors (one
        per image, or a single zero tensor when there is none), and the
        history or ``None``. Image tensors keep the processor's dtype; cast
        them yourself if the model runs in half precision.

    Raises:
        ValueError: if ``texts`` is malformed or more than ``max_num_images``
            images are given.

    The caller's ``texts`` and ``history`` are never modified.
    """
    # 1. preprocess texts
    if isinstance(texts, str):
        texts = [[texts, None]]
    elif not (isinstance(texts, list) and texts and isinstance(texts[0], list)):
        raise ValueError('texts must be a list of list')
    if history is not None:
        texts = history + texts  # concat them together
    # Moderation and placeholder insertion mutate pairs in place: use a private copy.
    texts = input_moderation(deepcopy(texts))

    # 2. preprocess images
    images = _load_images(images)
    if len(images) > max_num_images:
        raise ValueError(f'Currently at most {max_num_images} images are supported')

    # 3. collate conv
    new_history = deepcopy(texts)  # history is the texts without placeholders
    num_placeholders = 0 if any(img is None for img in images) else len(images)
    # Placeholders go only into the user input of the first round.
    texts[0][0] = insert_image_placeholder(texts[0][0], num_placeholders)
    conv = get_conv(texts)

    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)
    image_tensors = torch.stack(get_image_tensors(processor, images, device))

    if return_history:
        return input_ids, image_tensors, new_history
    return input_ids, image_tensors, None


class TextIterStreamer:
    """Queue-backed streamer yielding the decoded text generated so far.

    A producer (presumably Hugging Face ``generate(streamer=...)``) calls
    ``put`` with new token ids and ``end`` when finished; iterating blocks on
    the internal ``Queue``. Each yielded string is the decoding of *all*
    tokens received so far, not just the newest ones.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Append new token ids and enqueue the decoding of all tokens so far."""
        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first put() is treated as the prompt and dropped.
            self.next_tokens_are_prompt = False
            return
        if len(value.shape) > 1:
            value = value[0]  # only the first row of a batched input is kept
        self.tokens.extend(value.tolist())
        self.text_queue.put(
            self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        """Enqueue the end sentinel so iteration stops."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Block until new text arrives; raise StopIteration at the end sentinel."""
        value = self.text_queue.get()
        if value is None:
            raise StopIteration()
        return value