from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from PIL import Image
import requests
import copy
import torch

import sys
import warnings

warnings.filterwarnings("ignore")
pretrained = "lmms-lab/llava-onevision-qwen2-0.5b-si" |
|
model_name = "llava_qwen" |
|
device = "cuda" |
|
device_map = "auto" |
|
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) |
|
|
|
model.eval() |
|
|
|
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
image_tensor = process_images([image], image_processor, model.config) |
|
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor] |
|
|
|
conv_template = "qwen_1_5" |
|
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?" |
|
conv = copy.deepcopy(conv_templates[conv_template]) |
|
conv.append_message(conv.roles[0], question) |
|
conv.append_message(conv.roles[1], None) |
|
prompt_question = conv.get_prompt() |
|
|
|
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) |
|
image_sizes = [image.size] |
|
|
|
|
|

# Greedy decoding (do_sample=False), so the temperature setting has no effect here.
cont = model.generate(
    input_ids,
    images=image_tensor,
    image_sizes=image_sizes,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
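
# ---- Streaming variant ----
# The same image and prompt are prepared again, but the answer is generated with a
# TextIteratorStreamer so tokens can be printed as soon as they are produced.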

from threading import Thread
from transformers import TextIteratorStreamer
import json

url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
image_tensor = process_images([image], image_processor, model.config)
image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]
conv_template = "qwen_1_5" |
|
question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?" |
|
conv = copy.deepcopy(conv_templates[conv_template]) |
|
conv.append_message(conv.roles[0], question) |
|
conv.append_message(conv.roles[1], None) |
|
prompt_question = conv.get_prompt() |
|
|
|
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) |
|
image_sizes = [image.size] |
|
|
|
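
# Budget the generation length against the model's context window, accounting for the
# visual tokens that each image placeholder expands into.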
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
num_image_tokens = question.count(DEFAULT_IMAGE_TOKEN) * model.get_vision_tower().num_patches

# skip_prompt=True drops the prompt tokens from the stream; the timeout keeps iteration
# from blocking indefinitely if generation stalls.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

max_new_tokens = min(4096, max_context_length - input_ids.shape[-1] - num_image_tokens)

if max_new_tokens < 1:
    # The prompt (plus its image tokens) already fills the context window, so bail out.
    print(
        json.dumps(
            {
                "text": question + "Exceeds max token length. Please start a new conversation, thanks.",
                "error_code": 0,
            }
        )
    )
else:
    gen_kwargs = {
        "do_sample": False,
        "temperature": 0,
        "max_new_tokens": max_new_tokens,
        "images": image_tensor,
        "image_sizes": image_sizes,
    }

    # Run generation on a background thread so the streamer can be consumed in the main thread.
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            inputs=input_ids,
            streamer=streamer,
            **gen_kwargs,
        ),
    )
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        sys.stdout.write(new_text)
        sys.stdout.flush()

    print("\nFinal output:", generated_text)
|