import torch
from transformers import AutoTokenizer

from modeling_mmalaya import MMAlayaMPTForCausalLM
from mm_utils import DEFAULT_IMAGE_TOKEN


def load_pretrained_model(model_path, device_map="auto", device="cuda"):
    """Load the MMAlaya MPT model, tokenizer, and image processor from model_path."""
    kwargs = {"device_map": device_map}
    if device != "cuda":
        # Pin every module to the requested device instead of auto-sharding.
        kwargs["device_map"] = {"": device}
    kwargs["torch_dtype"] = torch.bfloat16

    print(f"Loading MPT model with kwargs: {kwargs}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MMAlayaMPTForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

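    # Register the image placeholder token and resize the embedding matrix
    # so the tokenizer and model vocabularies stay in sync.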
    tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Keep the vision tower in the same bfloat16 dtype as the language model
    # so the projected image features match the projector's weights.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.bfloat16)
    image_processor = vision_tower.image_processor

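    # Prefer the context length configured in the checkpoint; otherwise
    # fall back to a 2048-token window.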
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
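

# Example usage: a minimal sketch. The checkpoint path below is a
# placeholder, not a real model name; point it at your own weights.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        "checkpoints/mmalaya-mpt",  # hypothetical local checkpoint directory
    )
    print(f"context length: {context_len}")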