bingwork commited on
Commit
f4b2600
1 Parent(s): bc2260e

Delete mm_builder.py

Browse files
Files changed (1) hide show
  1. mm_builder.py +0 -33
mm_builder.py DELETED
@@ -1,33 +0,0 @@
1
- import torch
2
- from modeling_mmalaya import MMAlayaMPTForCausalLM
3
- from transformers import AutoTokenizer
4
- from mm_utils import DEFAULT_IMAGE_TOKEN
5
-
6
-
7
- def load_pretrained_model(model_path, device_map="auto", device="cuda"):
8
- kwargs = {"device_map": device_map}
9
- if device != "cuda":
10
- kwargs['device_map'] = {"": device}
11
- kwargs['torch_dtype'] = torch.bfloat16
12
-
13
- print('******** load mpt model from here kwargs ', kwargs)
14
- tokenizer = AutoTokenizer.from_pretrained(model_path)
15
- model = MMAlayaMPTForCausalLM.from_pretrained(
16
- model_path,
17
- low_cpu_mem_usage=True,
18
- **kwargs
19
- )
20
-
21
- tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
22
- model.resize_token_embeddings(len(tokenizer))
23
- vision_tower = model.get_vision_tower()
24
- vision_tower.to(device=device, dtype=torch.float16)
25
- image_processor = vision_tower.image_processor
26
-
27
- if hasattr(model.config, "max_sequence_length"):
28
- context_len = model.config.max_sequence_length
29
- else:
30
- context_len = 2048
31
-
32
- return tokenizer, model, image_processor, context_len
33
-