ALLaVA-StableLM2-1_6B / generation_utils.py
g-h-chen's picture
upload generation_utils.py
270e869 verified
raw
history blame
9.71 kB
from typing import List
from queue import Queue
import torch
from PIL import Image
from copy import deepcopy
import requests, os
# Id that tokenizer_image_token splices into input_ids in place of each
# "<image>" marker (presumably swapped for image features inside the model --
# TODO confirm against the modelling code).
IMAGE_TOKEN_INDEX=-200
# Strings that input_moderation deletes from every user/assistant message,
# presumably so free text cannot forge image placeholders or BOS/EOS markers.
blacklist = ['<image>', '<s>', '</s>']
# Upper bound on the number of images per request, checked in build_allava_input.
# NOTE(review): the trailing comment mentions phi, but this file ships with
# ALLaVA-StableLM2-1_6B -- confirm the context-length budget for that model.
max_num_images = 3 # phi has a context length limit of 2048 and each image occupies 576 tokens.
def input_moderation(texts: list[list[str]]):
    """Delete every ``blacklist`` string from each [user, assistant] pair, in place.

    Args:
        texts: list of two-element lists; element 0 is the user message and
            element 1 is the assistant reply or ``None`` (left untouched).

    Returns:
        The same ``texts`` object, after mutation.
    """
    def _scrub(message):
        # apply the removals in blacklist order, exactly one pass per entry
        for banned in blacklist:
            message = message.replace(banned, '')
        return message

    for pair in texts:
        pair[0] = _scrub(pair[0])
        if pair[1] is not None:
            pair[1] = _scrub(pair[1])
    return texts
def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
    """Return ``t`` prefixed with ``num_images`` copies of ``placeholder + sep``.

    A non-positive ``num_images`` returns ``t`` unchanged (same object).
    """
    if num_images <= 0:
        return t
    return f"{placeholder}{sep}" * num_images + t
def get_conv(texts):
    """Flatten [user, assistant] pairs into a list of turn dicts.

    Each pair yields ``{'from': 'human', 'value': pair[0]}`` followed by
    ``{'from': 'gpt', 'value': pair[1]}``. The assistant value of the last
    pair is typically ``None`` (the reply still to be generated).
    """
    return [
        {'from': role, 'value': pair[slot]}
        for pair in texts
        for slot, role in enumerate(('human', 'gpt'))
    ]
# copied from llava
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, emitting ``image_token_index`` for each literal ``<image>``.

    The prompt is split on ``<image>`` and every piece is tokenized without
    special tokens. If the first piece starts with the tokenizer's BOS id, that
    id is emitted once and the first token of *every* piece is then dropped
    (behaviour inherited from LLaVA).

    Returns:
        A list of token ids, or a 1-D ``torch.long`` tensor when
        ``return_tensors == 'pt'``.

    Raises:
        ValueError: for any other non-None ``return_tensors``.
    """
    pieces = [tokenizer(part, add_special_tokens=False).input_ids for part in prompt.split('<image>')]
    head = pieces[0] if pieces else []
    starts_with_bos = len(head) > 0 and head[0] == tokenizer.bos_token_id
    drop = 1 if starts_with_bos else 0
    token_ids = [head[0]] if starts_with_bos else []
    for index, piece in enumerate(pieces):
        if index > 0:
            # exactly one id per "<image>" separator
            token_ids.append(image_token_index)
        token_ids.extend(piece[drop:])
    if return_tensors is None:
        return token_ids
    if return_tensors == 'pt':
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
def preprocess(tokenizer, data: list, return_tensors='pt'):
    """Tokenize a conversation for the model.

    ``data`` is a list of turn dicts alternating
    ``{'from': 'human', 'value': ...}`` and ``{'from': 'gpt', 'value': ...}``
    (the format produced by ``get_conv``). Tokenization is specific to the
    model/tokenizer; this file dispatches to ``preprocess_allava``.

    Raises:
        ValueError: if ``data`` is not a list.
    """
    if isinstance(data, list):
        return preprocess_allava(tokenizer, data, return_tensors=return_tensors)
    raise ValueError('must be a list')
def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list: # tokenize and concat the coversations
    """Tokenize a Vicuna-v1 style conversation and concatenate the pieces.

    Looks lifted from a class method: it reads ``self.tokenizer``,
    ``self.tokenizer_image_token`` and ``self.maxlen``. No caller is visible here.

    Even positions of ``convs`` are user turns, encoded as ``"USER: {text} "``
    through ``self.tokenizer_image_token``. Odd positions are assistant turns,
    encoded as ``"ASSISTANT: {text}</s>"``, or as the bare generation prompt
    ``"ASSISTANT:"`` when the reply is ``None``.

    Returns:
        ``None`` for empty ``convs``, the single piece unchanged when there is
        only one, otherwise ``torch.cat`` of all pieces.
    """
    pieces = []
    for turn_no, turn in enumerate(convs):
        if turn_no % 2 == 0:  # human turn
            user_prompt = f"USER: {turn['value'].strip()} "
            pieces.append(self.tokenizer_image_token(prompt=user_prompt, return_tensors=return_tensors))
        else:  # gpt turn
            reply = turn['value']
            assistant_text = "ASSISTANT:" if reply is None else f"ASSISTANT: {reply}</s>"
            encoded = self.tokenizer(assistant_text, add_special_tokens=False, max_length=self.maxlen, truncation=True, return_tensors='pt')
            pieces.append(encoded.input_ids[0])
    if not pieces:
        return None
    if len(pieces) == 1:
        return pieces[0]
    return torch.cat(pieces)
def preprocess_allava(tokenizer, convs: list, return_tensors) -> list: # tokenize and concat the coversations
    """Tokenize an alternating human/gpt conversation in ``[INST] ... [/INST]`` format.

    Even positions are user turns: the stripped text is wrapped as
    ``"[INST] {text} [/INST] "`` and encoded with ``tokenizer_image_token`` (so
    ``<image>`` markers become ``IMAGE_TOKEN_INDEX``). Odd positions are
    assistant turns: the reply plus ``tokenizer.eos_token`` is encoded without
    special tokens. A ``None`` reply (the turn still to be generated)
    contributes nothing.

    Returns:
        ``None`` for empty ``convs``, the single piece unchanged when there is
        only one, otherwise ``torch.cat`` of all pieces.
    """
    segments = []
    for position, message in enumerate(convs):
        if position % 2:  # odd positions are assistant ('gpt') turns
            reply = message['value']
            if reply is None:
                continue  # pending reply: nothing to encode
            encoded = tokenizer(f"{reply}{tokenizer.eos_token}", add_special_tokens=False, truncation=True, return_tensors='pt')
            segments.append(encoded.input_ids[0])
        else:  # even positions are user ('human') turns
            prompt = f"[INST] {message['value'].strip()} [/INST] "
            segments.append(tokenizer_image_token(prompt=prompt, tokenizer=tokenizer, return_tensors=return_tensors))
    if not segments:
        return None
    if len(segments) == 1:
        return segments[0]
    return torch.cat(segments)
def _expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square canvas filled with ``background_color``, centring it.

    Any mode other than RGB/RGBA is converted to RGB first. A 3-tuple fill is
    rejected by single-band or two-band modes such as ``1``/``L``/``I``/``F``/``LA``.
    It gives the wrong colour for ``CMYK``. For palette (``P``) images, paste()
    would copy raw palette indices into a canvas with a different palette.
    """
    if pil_img.mode not in ('RGB', 'RGBA'):
        pil_img = pil_img.convert('RGB')
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # The short axis gets an offset of half the padding; the long axis gets 0.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result


# copied from llava
def get_image_tensors(processor, images, device):
    """Turn a list of images into preprocessed pixel tensors on ``device``.

    Args:
        processor: image processor exposing ``crop_size`` (a mapping with
            ``'height'``/``'width'``), ``image_mean`` (scaled by 255 here, so
            presumably in [0, 1]) and ``preprocess(image, return_tensors='pt')``
            returning a mapping with ``'pixel_values'``.
        images: iterable whose items are ``None`` (placeholder -> all-zero
            ``(3, height, width)`` tensor), a file path (opened and converted to
            RGB) or a ``PIL.Image.Image``.
        device: target device for every returned tensor.

    Returns:
        A list with one tensor per entry of ``images``.

    Raises:
        TypeError: for an entry of any other type.
    """
    crop_size = processor.crop_size
    list_image_tensors = []
    for fp in images:
        if fp is None:  # None is used as a placeholder
            list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
            continue
        if isinstance(fp, str):
            # convert() returns a fully loaded copy, so the file can be closed now.
            with Image.open(fp) as opened:
                image = opened.convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp  # already an image
        else:
            raise TypeError(f'Unsupported type {type(fp)}')
        # Pad to a square filled with the processor's mean colour: this is the
        # preprocessing used in training ('pad' aspect ratio), so impose it here.
        background = tuple(int(x * 255) for x in processor.image_mean)
        image = _expand2square(image, background)
        pixel_values = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        list_image_tensors.append(pixel_values.to(device))
    return list_image_tensors
def _load_images(images):
    """Resolve the ``images`` argument of ``build_allava_input`` into a list of PIL images.

    Accepts ``None``, a single path / URL / ``PIL.Image.Image``, or a list of
    them. Existing local paths are opened and converted to RGB. Any other
    non-Image entry is treated as a URL and fetched with ``requests.get``.
    Entries that cannot be loaded are skipped with a warning (best effort).
    Returns ``[None]`` when nothing usable remains, so the result is never empty.
    """
    import warnings

    if images is None:
        return [None]
    if isinstance(images, str) or isinstance(images, Image.Image):
        images = [images]
    loaded = []
    for img in images:
        if img is None:
            continue  # None is only a placeholder
        if isinstance(img, Image.Image):
            # Already decoded. os.path.exists() raises TypeError on an Image,
            # which a bare `except` would swallow, silently dropping it.
            loaded.append(img)
            continue
        try:
            if os.path.exists(img):  # make sure that the path exists
                with Image.open(img) as opened:
                    loaded.append(opened.convert('RGB'))
            else:  # else it must be a URL
                raw = requests.get(img, stream=True).raw
                # Decode eagerly and normalise to RGB, matching the local-file
                # branch, so a broken download is skipped here rather than
                # failing later in get_image_tensors.
                loaded.append(Image.open(raw).convert('RGB'))
        except Exception as err:  # deliberate best effort: skip, but report it
            warnings.warn(f'Skipping image {img!r}: {err}')
    return loaded or [None]


def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    '''
    Build model inputs (token ids and image tensors) for one generation call.

    Args:
        tokenizer: tokenizer passed through to ``preprocess``.
        processor: image processor passed to ``get_image_tensors``.
        texts: a single user message (str), or a list of ``[user, assistant]``
            pairs whose last assistant entry is normally ``None`` (the reply
            to generate).
        images: ``None``, a local path, a URL, a ``PIL.Image.Image``, or a list
            of these. At most ``max_num_images`` usable images are accepted.
        history: pairs returned by a previous call with ``return_history=True``;
            they are prepended to ``texts``.
        return_history: if True, also return the moderated conversation
            (without ``<image>`` placeholders) for use as the next ``history``.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, image_tensors, history_or_None)``. ``input_ids`` has a
        leading batch dimension of 1. ``image_tensors`` stacks one tensor per
        image, or a single all-zero placeholder when no image is usable.

    Neither ``texts`` nor ``history`` is mutated.
    '''
    ############################
    # 1. preprocess texts
    ############################
    if isinstance(texts, str):
        texts = [[texts, None]]
    else:
        assert isinstance(texts, list) and len(texts) > 0 and isinstance(texts[0], list), 'texts must be a list of list'
        # input_moderation and the placeholder insertion below mutate the inner
        # lists in place; work on copies so the caller's lists stay intact.
        texts = deepcopy(texts)
    if history is not None:
        texts = deepcopy(history) + texts  # concat them together
    texts = input_moderation(texts)

    ############################
    # 2. preprocess images
    ############################
    images = _load_images(images)
    assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'

    ############################
    # 3. collate conv
    ############################
    history = deepcopy(texts)  # history is the texts without <image> placeholders
    # Only the user input of the first round gets placeholders, and none at all
    # when the [None] "no image" sentinel is used.
    num_placeholders = 0 if any(img is None for img in images) else len(images)
    texts[0][0] = insert_image_placeholder(texts[0][0], num_placeholders)
    conv = get_conv(texts)
    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)
    image_tensors = torch.stack(get_image_tensors(processor, images, device))
    return input_ids, image_tensors, (history if return_history else None)
class TextIterStreamer:
    """Streamer that re-decodes all received tokens on every step.

    Each ``put`` appends new token ids and enqueues the *cumulative* decoded
    text (not just the delta). ``end`` enqueues a ``None`` sentinel, which
    stops iteration. When ``skip_prompt`` is set, the first ``put`` call is
    discarded.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []  # every token id accepted so far
        self.text_queue = Queue()  # decoded snapshots for the consumer; None ends the stream
        self.next_tokens_are_prompt = True  # only consulted when skip_prompt is set

    def put(self, value):
        """Accept token ids (``value`` needs ``.shape``/``.tolist()``; only row 0 of a 2-D input is used)."""
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        ids = value[0] if len(value.shape) > 1 else value
        self.tokens += ids.tolist()
        snapshot = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(snapshot)

    def end(self):
        """Signal that generation has finished."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.text_queue.get()
        if item is None:
            raise StopIteration()
        return item