File size: 8,318 Bytes
8fbc209 86744eb f16e094 8fbc209 f16e094 8fbc209 86744eb 8fbc209 669c11e 8fbc209 669c11e 8fbc209 f16e094 75c15ae 8fbc209 86744eb 8fbc209 86744eb 669c11e 86744eb 8fbc209 86744eb 335eee6 86744eb 669c11e 86744eb 669c11e 86744eb f748941 86744eb 669c11e 86744eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import PIL
import torch
from .modeling_llava import LlavaForConditionalGeneration
from .processing_llava import MLlavaProcessor
# from ..conversation import conv_mllava_v1_mmtag as default_conv
from ..conversation import conv_mllava_v1 as default_conv, conv_templates
from typing import List, Tuple, Union, Tuple
def chat_mllava(
    text: str,
    images: List[Union[PIL.Image.Image, str]],
    model: LlavaForConditionalGeneration,
    processor: MLlavaProcessor,
    max_input_length: int = None,
    history: List[dict] = None,
    **kwargs,
) -> Tuple[str, List[dict]]:
    """
    Chat with the Mllava model (single, non-streaming generation).

    Args:
        text: str, the text to be sent to the model, where <image> will be the placeholder
            for the image. May be empty when ``history`` already ends with a user turn (or with
            an empty assistant turn) that the model should answer.
        images: List[PIL.Image.Image | str], the images (or image file paths) to be sent to the
            model, or None. Paths are opened and every image is converted to RGB; the caller's
            list itself is not modified.
        model: LlavaForConditionalGeneration, the model to be used
        processor: MLlavaProcessor, the processor to be used
        max_input_length: int, the maximum input length; the processed prompt is truncated to it
        history: List[dict], list of messages in the conversation as history. Each message is a
            dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None or empty, the
            conversation starts from scratch. A provided list is extended in place with the new
            turn(s), and its last entry is filled with the generated text.
        kwargs: dict, the generation kwargs forwarded to ``model.generate``
    Returns:
        Tuple[str, List[dict]], the generated text and the history of the conversation
    Raises:
        AssertionError: if a history role is not one of the template's roles, or the history
            does not end with the turn expected for the given ``text``.
        ValueError: if the processor returns an input of an unsupported type.
    """
    if "llama-3" in model.language_model.name_or_path.lower():
        conv = conv_templates['llama_3']
        # Stop on either the tokenizer's EOS token or <|eot_id|>.
        terminators = [
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        # An explicit eos_token_id from the caller takes precedence.
        kwargs.setdefault("eos_token_id", terminators)
    else:
        conv = default_conv
        # Leave eos_token_id unset: passing None could override the model's own EOS setting.

    conv = conv.copy()
    conv.messages = []
    if history is None:
        history = []
    for message in history:
        assert message["role"] in conv.roles
        conv.append_message(message["role"], message["text"])

    if not conv.messages or text:
        # New user turn, followed by an empty assistant slot that generation will fill.
        if conv.messages:
            assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], "")
        history.append({"role": conv.roles[0], "text": text})
        history.append({"role": conv.roles[1], "text": ""})
    else:
        # No new text: answer the history as it stands.
        if conv.messages[-1][0] == conv.roles[1]:
            # History already ends with an empty assistant slot.
            assert conv.messages[-1][1] == "", "No user message should be provided"
        else:
            assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
            # Open the assistant slot (the original wrongly appended a second user turn here).
            conv.append_message(conv.roles[1], "")
            history.append({"role": conv.roles[1], "text": ""})

    assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
    assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
    prompt = conv.get_prompt()

    if images:
        # Build a new list so the caller's `images` is left untouched.
        loaded_images = []
        for image in images:
            if isinstance(image, str):
                # Close the file handle once the RGB copy has been made.
                with PIL.Image.open(image) as opened:
                    image = opened.convert("RGB")
            else:
                image = image.convert("RGB")
            loaded_images.append(image)
        images = loaded_images

    inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
    for key, value in inputs.items():
        if value is None:
            continue
        if isinstance(value, torch.Tensor):
            inputs[key] = value.to(model.device)
        elif isinstance(value, list):
            inputs[key] = [x.to(model.device) for x in value]
        else:
            raise ValueError(f"Invalid input type: {type(value)}")

    output_ids = model.generate(**inputs, **kwargs)
    # Remove the input (prompt) tokens from the first returned sequence.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = output_ids[0][prompt_length:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True)
    history[-1]["text"] = generated_text
    return generated_text, history
def chat_mllava_stream(
    text: str,
    images: List[Union[PIL.Image.Image, str]],
    model: LlavaForConditionalGeneration,
    processor: MLlavaProcessor,
    max_input_length: int = None,
    history: List[dict] = None,
    **kwargs,
) -> Tuple[str, List[dict]]:
    """
    Chat with the Mllava model, streaming the answer as it is generated.

    This is a generator: generation runs in a background thread and text chunks are read
    from a ``TextIteratorStreamer``.

    Args:
        text: str, the text to be sent to the model, where <image> will be the placeholder
            for the image. May be empty when ``history`` already ends with a user turn (or with
            an empty assistant turn) that the model should answer.
        images: List[PIL.Image.Image | str], the images (or image file paths) to be sent to the
            model, or None. Paths are opened and every image is converted to RGB; the caller's
            list itself is not modified.
        model: LlavaForConditionalGeneration, the model to be used
        processor: MLlavaProcessor, the processor to be used
        max_input_length: int, the maximum input length; the processed prompt is truncated to it
        history: List[dict], list of messages in the conversation as history. Each message is a
            dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None or empty, the
            conversation starts from scratch. A provided list is extended in place with the new
            turn(s).
        kwargs: dict, the generation kwargs forwarded to ``model.generate`` (any ``streamer``
            entry is replaced by this function's own streamer)
    Yields:
        Tuple[str, List[dict]], after each streamed chunk: the text generated so far and the
        history, whose last entry is updated in place with that text.
    Raises:
        AssertionError: if a history role is not one of the template's roles, or the history
            does not end with the turn expected for the given ``text``.
        ValueError: if the processor returns an input of an unsupported type.
        Exception: any error raised by ``model.generate`` in the background thread is
            re-raised here once the stream has been closed.
    """
    import logging
    from threading import Thread
    from transformers import TextIteratorStreamer

    if "llama-3" in model.language_model.name_or_path.lower():
        conv = conv_templates['llama_3']
        # Stop on either the tokenizer's EOS token or <|eot_id|>.
        terminators = [
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        # An explicit eos_token_id from the caller takes precedence.
        kwargs.setdefault("eos_token_id", terminators)
    else:
        conv = default_conv
        # Leave eos_token_id unset: passing None could override the model's own EOS setting.

    conv = conv.copy()
    conv.messages = []
    if history is None:
        history = []
    for message in history:
        assert message["role"] in conv.roles
        conv.append_message(message["role"], message["text"])

    if not conv.messages or text:
        # New user turn, followed by an empty assistant slot that generation will fill.
        if conv.messages:
            assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], "")
        history.append({"role": conv.roles[0], "text": text})
        history.append({"role": conv.roles[1], "text": ""})
    else:
        # No new text: answer the history as it stands.
        if conv.messages[-1][0] == conv.roles[1]:
            # History already ends with an empty assistant slot.
            assert conv.messages[-1][1] == "", "No user message should be provided"
        else:
            assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
            # Open the assistant slot (the original wrongly appended a second user turn here).
            conv.append_message(conv.roles[1], "")
            history.append({"role": conv.roles[1], "text": ""})

    assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
    assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
    prompt = conv.get_prompt()

    if images:
        # Build a new list so the caller's `images` is left untouched.
        loaded_images = []
        for image in images:
            if isinstance(image, str):
                # Close the file handle once the RGB copy has been made.
                with PIL.Image.open(image) as opened:
                    image = opened.convert("RGB")
            else:
                image = image.convert("RGB")
            loaded_images.append(image)
        images = loaded_images

    inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)

    # Debug visibility of the exact prompt, instead of an unconditional print to stdout.
    logger = logging.getLogger(__name__)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Prompt sent to the model:\n%s", processor.tokenizer.decode(inputs["input_ids"][0]))

    for key, value in inputs.items():
        if value is None:
            continue
        if isinstance(value, torch.Tensor):
            inputs[key] = value.to(model.device)
        elif isinstance(value, list):
            inputs[key] = [x.to(model.device) for x in value]
        else:
            raise ValueError(f"Invalid input type: {type(value)}")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Caller kwargs override processor outputs; our streamer always wins.
    generate_kwargs = {**inputs, **kwargs, "streamer": streamer}
    errors = []

    def _generate():
        # Background worker: if generation fails, end the stream so the loop below stops
        # instead of blocking forever, and keep the exception for the caller's thread.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    for new_text in streamer:
        history[-1]["text"] += new_text
        yield history[-1]["text"], history
    thread.join()
    if errors:
        raise errors[0]