How to run inference?
I have tried to use the model in the following way, but I am getting an error that the number of image tokens inside the prompt is 0. What am I doing wrong, and what is the special token for an image?
Could you provide the simplest inference script?
Best
from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration
model = LlavaForConditionalGeneration.from_pretrained("SpursgoZmy/table-llava-v1.5-7b", load_in_8bit=True, device_map='auto', revision='d033ef5f1ef171e467240fc2cf9dec61960c87e8')
processor = AutoProcessor.from_pretrained("SpursgoZmy/table-llava-v1.5-7b", revision='d033ef5f1ef171e467240fc2cf9dec61960c87e8')
prompt = "<image>\nWhat's the content of the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_new_tokens=15)
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(outputs)
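For reference, one way to sanity-check the "image tokens ... is 0" error is to see whether the loaded tokenizer even treats <image> as a known token (a quick diagnostic sketch, not a fix):
# Quick check: does this checkpoint's tokenizer know the <image> placeholder?
tok = processor.tokenizer
print("<image>" in tok.get_vocab())          # False would explain why no image tokens are found in the prompt
print(tok.convert_tokens_to_ids("<image>"))  # unknown tokens fall back to the unk id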
The GitHub documentation mentions following LLaVA, so I also tried testing with the original LLaVA codebase, but the results aren't coming out well. It seems I might be doing something wrong. Could you please provide example code?
https://github.com/SpursGoZmy/Table-LLaVA/blob/main/scripts/v1_5/table_llava_scripts/table_llava_inference.sh
# Table-LLaVA/scripts/v1_5/table_llava_scripts/table_llava_inference.sh
for IDX in $(seq 0 $((CHUNKS-1))); do
    CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa \
        --model-path path_to_LLM_weights/table-llava-v1.5-7b/ \
        --question-file ./LLaVA-Inference/MMTab-eval_test_data_49K_llava_jsonl_format.jsonl \
        --image-folder ./LLaVA-Inference/all_test_image \
        --answers-file ./eval_results/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
        --num-chunks $CHUNKS \
        --chunk-idx $IDX \
        --temperature 0 \
        --conv-mode vicuna_v1 &
done
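For context, the --num-chunks / --chunk-idx flags just shard the question JSONL across the GPUs in GPULIST; roughly, llava.eval.model_vqa does something like the following with them (a sketch of the chunking idea, not the exact repo code):
import json
import math

def get_chunk(questions, num_chunks, chunk_idx):
    # Split the question list into num_chunks roughly equal slices and return slice chunk_idx.
    chunk_size = math.ceil(len(questions) / num_chunks)
    chunks = [questions[i:i + chunk_size] for i in range(0, len(questions), chunk_size)]
    return chunks[chunk_idx]

with open("./LLaVA-Inference/MMTab-eval_test_data_49K_llava_jsonl_format.jsonl") as f:
    questions = [json.loads(line) for line in f]
my_share = get_chunk(questions, num_chunks=8, chunk_idx=0)  # e.g. 8 GPUs, this worker handles chunk 0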
My test code:
import os

import torch
from PIL import Image

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, tokenizer_image_token
from llava.utils import disable_torch_init
model_path = "SpursgoZmy/table-llava-v1.5-7b"
disable_torch_init()
model_path = os.path.expanduser(model_path)
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device='cpu')
from jinja2 import Template
template_str = """
{% for message in messages %}
{% if message['role'] != 'system' %}
{{ message['role'].upper() + ': ' }}
{% endif %}
{# Render all images first #}
{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}
{{ '<image>\n' }}
{% endfor %}
{# Render all text next #}
{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}
{{ content['text'] + ' ' }}
{% endfor %}
{% endfor %}
{% if add_generation_prompt %}
{{ 'ASSISTANT:' }}
{% endif %}
"""
conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "This picture illustrates a table. Please represent this table with the markdown format in text.",
            },
            {"type": "image"},
        ],
    },
]
template = Template(template_str)
prompt = template.render(messages=conversation, add_generation_prompt=True)
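# The rendered prompt is essentially "USER: <image>\n<question> ASSISTANT:" (plus some blank
# lines from the template's control tags); tokenizer_image_token below replaces the <image>
# placeholder with IMAGE_TOKEN_INDEX so the model knows where to splice in the image features.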
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)
image = Image.open('./img/table.png')
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
device = torch.device('cpu')
model.to(device).float()
with torch.inference_mode():
    output_ids = model.generate(
        input_ids.to(device),
        images=image_tensor.unsqueeze(0).float().to(device),
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        # num_beams=args.num_beams,
        # no_repeat_ngram_size=3,
        max_new_tokens=4096,
        use_cache=True)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
print(outputs)
Maybe this is because the saved Table-LLaVA checkpoints from the original LLaVA repository are not directly compatible with Transformers, which is mentioned in this GitHub issue. I will try the provided conversion script and upload new checkpoints. But for now, the checkpoints can probably only be loaded locally instead of from Hugging Face, i.e., download the Table-LLaVA checkpoints and set 'model-path' to your local path to the model weights folder.
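For example, a minimal sketch of that local-path workaround with the LLaVA codebase (the snapshot_download step and the folder name are just illustrative, not an official recipe):
from huggingface_hub import snapshot_download
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

# Download the checkpoint to a local folder first (or git clone / manually download it).
local_path = snapshot_download("SpursgoZmy/table-llava-v1.5-7b", local_dir="./table-llava-v1.5-7b")
tokenizer, model, image_processor, context_len = load_pretrained_model(
    local_path, None, get_model_name_from_path(local_path)
)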