import torch
from transformers import AutoTokenizer

from modeling_mmalaya import MMAlayaMPTForCausalLM
from mm_utils import DEFAULT_IMAGE_TOKEN


def load_pretrained_model(model_path, device_map="auto", device="cuda"):
    kwargs = {"device_map": device_map}
    if device != "cuda":
        # Pin the whole model to the requested device instead of auto-sharding.
        kwargs["device_map"] = {"": device}
    kwargs["torch_dtype"] = torch.bfloat16
    print("******** load mpt model from here kwargs ", kwargs)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MMAlayaMPTForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # Register the image placeholder token and grow the embedding table to match.
    tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # Move the vision tower to the target device and fetch its image processor.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Context length comes from the config when available, otherwise defaults to 2048.
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
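

# --- Usage sketch (illustration only) ---------------------------------------
# A minimal example of calling load_pretrained_model. The checkpoint path and
# image file below are placeholder assumptions, not artifacts shipped with this
# repository, and the preprocessing call assumes the vision tower exposes a
# transformers-style image processor.
if __name__ == "__main__":
    from PIL import Image

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        "path/to/MMAlaya-checkpoint",  # hypothetical local checkpoint directory
        device="cuda",
    )
    print("context length:", context_len)

    # Preprocess a single image with the vision tower's processor.
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
    print("pixel_values shape:", tuple(pixel_values.shape))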