import torch
from transformers import AutoTokenizer

from mm_utils import DEFAULT_IMAGE_TOKEN
from modeling_mmalaya import MMAlayaMPTForCausalLM


def load_pretrained_model(model_path, device_map="auto", device="cuda"):
    """Load the MMAlaya MPT model, its tokenizer, and the vision tower's image processor."""
    kwargs = {"device_map": device_map}
    # Pin every module to the requested device when it is not the default "cuda".
    if device != "cuda":
        kwargs["device_map"] = {"": device}
    kwargs["torch_dtype"] = torch.bfloat16

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MMAlayaMPTForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # Register the image placeholder token and grow the embedding matrix to match.
    tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Note: the vision tower is cast to float16 while the language model loads in bfloat16.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Fall back to a 2048-token context when the config does not declare one.
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
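

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving load_pretrained_model() above. The checkpoint
# path, image file, prompt wording, and the `images=` keyword on generate()
# are assumptions modeled on LLaVA-style loaders; the real checkpoint name
# and generation interface may differ.
if __name__ == "__main__":
    from PIL import Image

    model_path = "DataCanvas/MMAlaya"  # hypothetical checkpoint path
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path)

    # Preprocess one image with the vision tower's processor; match the
    # float16 dtype the vision tower was cast to in the loader.
    image = Image.open("example.jpg").convert("RGB")  # hypothetical local file
    image_tensor = image_processor(images=image, return_tensors="pt")["pixel_values"]
    image_tensor = image_tensor.to(device="cuda", dtype=torch.float16)

    # Put the image placeholder token in the prompt so the model can splice
    # the vision features in at that position.
    prompt = f"{DEFAULT_IMAGE_TOKEN}\nDescribe this image."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,  # assumed LLaVA-style kwarg, not confirmed by the source
            max_new_tokens=128,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))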