MMAlaya / mm_builder.py
bingwork's picture
upload model scripts
1f9001d verified
raw
history blame
1.11 kB
import torch
from modeling_mmalaya import MMAlayaMPTForCausalLM
from transformers import AutoTokenizer
from mm_utils import DEFAULT_IMAGE_TOKEN
def load_pretrained_model(model_path, device_map="auto", device="cuda"):
    """Load the MMAlaya MPT causal LM, its tokenizer, and its vision tower.

    Parameters
    ----------
    model_path : str
        Local path or hub id passed to ``from_pretrained`` for both the
        tokenizer and the model.
    device_map : str or dict, default "auto"
        Forwarded to ``from_pretrained``; overridden with ``{"": device}``
        when *device* is not ``"cuda"``.
    device : str, default "cuda"
        Target device for the vision tower and, on non-CUDA targets, the
        whole model (via ``device_map``).

    Returns
    -------
    tuple
        ``(tokenizer, model, image_processor, context_len)``.
    """
    kwargs = {"device_map": device_map}
    if device != "cuda":
        # Pin every submodule to the requested device and load in bfloat16
        # on non-CUDA targets.
        kwargs['device_map'] = {"": device}
        kwargs['torch_dtype'] = torch.bfloat16
    print('******** load mpt model from here kwargs ', kwargs)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MMAlayaMPTForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # Register the image placeholder token and grow the embedding matrix so
    # the newly added token id has an embedding row.
    tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): the vision tower is cast to float16 even on the non-CUDA
    # path where the rest of the model was loaded in bfloat16 — confirm this
    # mixed-dtype setup is intentional.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Fall back to 2048 when the config does not declare a max length.
    context_len = getattr(model.config, "max_sequence_length", 2048)

    return tokenizer, model, image_processor, context_len