import torch
from modeling_mmalaya import MMAlayaMPTForCausalLM
from transformers import AutoTokenizer
from mm_utils import DEFAULT_IMAGE_TOKEN


def load_pretrained_model(model_path, device_map="auto", device="cuda"):
    # Build the keyword arguments for from_pretrained. When a specific device
    # is requested (e.g. "cpu"), place the whole model on that device instead
    # of letting accelerate auto-shard it.
    kwargs = {"device_map": device_map}
    if device != "cuda":
        kwargs["device_map"] = {"": device}
    kwargs["torch_dtype"] = torch.bfloat16

    print("Loading MPT model with kwargs:", kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MMAlayaMPTForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # Register the image placeholder token and grow the embedding matrix so
    # the new token id has a corresponding row.
    tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    # The vision tower is kept in float16 (the language model itself was
    # loaded in bfloat16 above) and moved to the target device.
    vision_tower = model.get_vision_tower()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Use the model's configured maximum sequence length if available,
    # otherwise fall back to a default context window of 2048 tokens.
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
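

# Example usage (a minimal sketch): the checkpoint directory below is a
# hypothetical placeholder, not a path shipped with this repository.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        "path/to/MMAlaya-checkpoint",  # hypothetical local checkpoint directory
        device="cuda",
    )
    print("Context length:", context_len)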