Several issues loading and using the model with transformers==4.39.2

#7
by csegalin - opened
class LlavaMistralCaptioner:
    """Image captioner backed by LLaVA-NeXT v1.6 with a Mistral-7B language model.

    The Hugging Face checkpoint is loaded once at construction time; call
    :meth:`caption` per image afterwards.
    """

    def __init__(self, device='cuda', 
                 hf_model="llava-hf/llava-v1.6-mistral-7b-hf", 
                 bf16=False, 
                 quant_force=True,
                 ):
        """Load the model and processor.

        Args:
            device: CUDA device string that inputs are moved to (e.g. ``'cuda'``).
            hf_model: Hugging Face model id or local checkpoint path.
            bf16: if True use bfloat16 as the compute dtype, otherwise float16.
            quant_force: if True, always load 4-bit quantized even when the GPU
                has >= 40 GB of memory.
        """
        # FIX: torch was used below (dtype constants, CUDA memory query) but was
        # never imported, which raised NameError on construction.
        import torch
        from transformers import (LlavaNextProcessor,
                                  LlavaNextForConditionalGeneration,
                                  BitsAndBytesConfig)

        self.device = device
        self.torch_type = torch.bfloat16 if bf16 else torch.float16

        # Auto-enable 4-bit quantization when the GPU has less than 40 GB total.
        with torch.cuda.device(self.device):
            _, total_bytes = torch.cuda.mem_get_info()
        quant = total_bytes / (1 << 30) < 40

        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=self.torch_type,
        )
        print("========Use torch type as:{} with device:{}========\n".format(self.torch_type, self.device))

        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=hf_model,
            torch_dtype=self.torch_type,
            low_cpu_mem_usage=True,
            # NOTE(review): requires flash-attn to be installed; fall back to
            # "sdpa" if it is not available in the environment.
            attn_implementation="flash_attention_2",
            quantization_config=self.quantization_config if quant or quant_force else None,
            # device_map='auto'
        ).eval()
        # Tie input/output embeddings, as the loader warning requests.
        self.model.tie_weights()
        self.processor = LlavaNextProcessor.from_pretrained(hf_model)

    def caption(self, image, 
                 prompt, 
                 max_tokens=225,
                 top_k=1, 
                 top_p=0.1, 
                 num_beams=1,
                 do_sample=True, 
                 temperature=0.1,
                 use_cache=True):
        """Generate a cleaned-up caption for ``image``.

        Args:
            image: input image accepted by the LLaVA-NeXT processor (e.g. PIL).
            prompt: instruction text inserted into the Mistral chat template.
            max_tokens: maximum number of newly generated tokens.
            top_k / top_p / num_beams / do_sample / temperature / use_cache:
                passed through to ``model.generate``.

        Returns:
            The generated caption as a single-line string, with the prompt
            echo, special tokens, and markdown asterisks stripped.
        """
        import re
        # Mistral-style instruction template; <image> is the image placeholder.
        prompt = f'''[INST] <image>\n {prompt} [/INST]'''

        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_type)
        outputs = self.model.generate(**inputs,
                                      max_new_tokens=max_tokens,
                                      top_k=top_k,
                                      top_p=top_p,
                                      num_beams=num_beams,
                                      # Any positive temperature implies sampling.
                                      do_sample=True if temperature > 0 else do_sample,
                                      temperature=temperature,
                                      use_cache=use_cache,
                                      # pad_token_id=2, 
                                      # num_return_sequences=1
                                      )
        response = self.processor.decode(outputs[0],
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)
        # Keep only the assistant's part (after the final [/INST]), collapse
        # newlines, and strip any leftover special tokens / markdown emphasis.
        response = response.split('[/INST]')[-1].strip()
        response = re.sub(r'\n+', ' ', response)
        response = response.strip().replace("</s>", "").replace("<s>", "").replace("*", " ")
        return response

1. When loading the model, I get:
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

2. When generating a caption, I get the same caption repeated three times.

Any help on this?

Sign up or log in to comment