Several issues loading and using the model with transformers==4.39.2
#7 · opened by csegalin
This is the wrapper class I'm using (indentation restored; note it also needs `import torch` at the top):

```python
import torch


class LlavaMistralCaptioner:
    def __init__(self, device='cuda',
                 hf_model="llava-hf/llava-v1.6-mistral-7b-hf",
                 bf16=False,
                 quant_force=True,
                 ):
        from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig, AutoProcessor

        self.device = device
        self.torch_type = torch.bfloat16 if bf16 else torch.float16

        # Quantize automatically when the GPU has less than 40 GB of total memory
        with torch.cuda.device(self.device):
            _, total_bytes = torch.cuda.mem_get_info()
        total_gb = total_bytes / (1 << 30)
        quant = total_gb < 40

        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=self.torch_type,
        )

        print("========Use torch type as:{} with device:{}========\n".format(self.torch_type, self.device))

        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=hf_model,
            torch_dtype=self.torch_type,
            low_cpu_mem_usage=True,
            attn_implementation="flash_attention_2",
            quantization_config=self.quantization_config if quant or quant_force else None,
            # device_map='auto'
        ).eval()
        self.model.tie_weights()
        # self.processor = AutoProcessor.from_pretrained(hf_model)
        self.processor = LlavaNextProcessor.from_pretrained(hf_model)

    def caption(self, image,
                prompt,
                max_tokens=225,
                top_k=1,
                top_p=0.1,
                num_beams=1,
                do_sample=True,
                temperature=0.1,
                use_cache=True):
        import re

        prompt = f'[INST] <image>\n {prompt} [/INST]'
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_type)
        outputs = self.model.generate(**inputs,
                                      max_new_tokens=max_tokens,
                                      top_k=top_k,
                                      top_p=top_p,
                                      num_beams=num_beams,
                                      do_sample=True if temperature > 0 else do_sample,
                                      temperature=temperature,
                                      use_cache=use_cache,
                                      # pad_token_id=2,
                                      # num_return_sequences=1
                                      )
        response = self.processor.decode(outputs[0],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        # Keep only the assistant turn and normalize whitespace
        response = response.split('[/INST]')[-1].strip()
        response = re.sub(r'\n+', ' ', response)
        response = response.strip().replace("</s>", "").replace("<s>", "").replace("*", " ")
        return response
```
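For context, this is how I call it (the image path is just illustrative):

```python
from PIL import Image

captioner = LlavaMistralCaptioner(device='cuda', bf16=False)
image = Image.open("example.jpg")  # illustrative path
print(captioner.caption(image, "Describe this image."))
```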
1. When I load the model I get these warnings:

```
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
```
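The first warning is the one that confuses me most, since I do pass `torch_dtype`. A minimal sketch of the load path that still produces it on my side (transformers==4.39.2; the config mirrors my class above):

```python
import torch
from transformers import LlavaNextForConditionalGeneration, BitsAndBytesConfig

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,  # dtype is passed explicitly, yet the warning still appears
    attn_implementation="flash_attention_2",
    quantization_config=bnb,
)
```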
2. When generating a caption, I get the same caption repeated three times in the output.
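One workaround I'm considering (not verified) is to decode only the newly generated tokens and add a `repetition_penalty`, in case the model is looping. A sketch of the change inside `caption()`:

```python
# Sketch: decode only the tokens generated after the prompt and penalize
# repetition. The repetition_penalty value is a guess, not tuned.
prompt_len = inputs["input_ids"].shape[1]
outputs = self.model.generate(**inputs,
                              max_new_tokens=max_tokens,
                              do_sample=temperature > 0,
                              temperature=temperature,
                              repetition_penalty=1.2)
response = self.processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
```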
Any help on this?