File size: 3,301 Bytes
dd4cd4b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
import warnings
from utils.utils import *
from config import *
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig
# Suppress all Python warnings process-wide. filterwarnings is called with only
# action='ignore' (no category/module/message filter), and it runs at import time.
# NOTE(review): presumably meant to hide noisy library warnings during model loading.
# Consider narrowing the filter so unrelated warnings are not hidden too.
warnings.filterwarnings(action='ignore')
def _resolve_variant(size):
    """Import the Phantom architecture for ``size`` and load its tokenizer.

    Args:
        size: Model-size key; one of ``'7b'``, ``'3.8b'``, ``'1.8b'``, ``'0.5b'``.

    Returns:
        Tuple ``(model_cls, tokenizer, path, extra_skip)``:
        - ``model_cls`` is the size-specific ``PhantomForCausalLM`` class.
        - ``tokenizer`` is loaded from ``path`` with ``padding_side='left'``.
        - ``extra_skip`` lists size-specific module names to exclude from
          bitsandbytes quantization.

    Raises:
        ValueError: If ``size`` is not one of the supported keys.
    """
    # Imports live inside each branch, so only the requested variant's code is imported.
    if size == '7b':
        from .arch_7b.modeling_phantom import PhantomForCausalLM
        from .arch_7b.tokenization_internlm2 import InternLM2Tokenizer as PhantomTokenizer
        tokenizer_cls = PhantomTokenizer
        path = MODEL_7B
        extra_skip = ["mlp1", "wqkv", "output"]
    elif size == '3.8b':
        from .arch_3_8b.modeling_phantom import PhantomForCausalLM
        tokenizer_cls = AutoTokenizer
        path = MODEL_3_8B
        extra_skip = ["mlp1", "qkv_proj", "phantom", "lm_head"]
    elif size == '1.8b':
        from .arch_1_8b.modeling_phantom import PhantomForCausalLM
        from .arch_1_8b.tokenization_internlm2 import InternLM2Tokenizer as PhantomTokenizer
        tokenizer_cls = PhantomTokenizer
        path = MODEL_1_8B
        extra_skip = ["mlp1", "wqkv", "phantom", "output"]
    elif size == '0.5b':
        from .arch_0_5b.modeling_phantom import PhantomForCausalLM
        tokenizer_cls = AutoTokenizer
        path = MODEL_0_5B
        extra_skip = ["mlp1", "q_proj", "k_proj", "v_proj", "phantom", "lm_head"]
    else:
        raise ValueError(f"Unsupported Size: {size!r}")

    tokenizer = tokenizer_cls.from_pretrained(path, padding_side='left')
    return PhantomForCausalLM, tokenizer, path, extra_skip


def _build_load_kwargs(bits, skip_modules):
    """Build the ``from_pretrained`` keyword arguments for a Phantom checkpoint.

    Args:
        bits: 4 or 8 enables bitsandbytes quantization at that width. Any other
            value loads unquantized bfloat16 weights.
        skip_modules: Module names passed as ``llm_int8_skip_modules``, i.e. kept
            out of quantization.

    Returns:
        Dict of keyword arguments for ``PhantomForCausalLM.from_pretrained``.
    """
    # Options shared by the quantized and unquantized paths.
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2",
    )
    if bits in (4, 8):
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=bits == 4,
            load_in_8bit=bits == 8,
            llm_int8_skip_modules=skip_modules,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4',
        )
    return load_kwargs


def load_model(size, bits=8):
    """Load a Phantom model and its tokenizer for the requested size.

    Args:
        size: Model-size key; one of ``'7b'``, ``'3.8b'``, ``'1.8b'``, ``'0.5b'``.
        bits: Weight quantization width. 4, or 8 (the default, matching the
            previous hard-coded behavior), loads through bitsandbytes. Any other
            value (e.g. 16 or None) loads unquantized bfloat16 weights.

    Returns:
        ``(model, tokenizer)``. The model has been passed to ``freeze_model``,
        switched to eval mode, and had its float32/float16 parameters cast to
        bfloat16.

    Raises:
        ValueError: If ``size`` is not a supported key.
    """
    # Phantom-specific modules are always excluded from quantization; each size
    # appends its own architecture-specific module names.
    bit_quant_skip = [
        "linear_q", "linear_k", "linear_v", "linear_o",
        "gating_phantom_1", "gating_phantom_2",
    ]
    model_cls, tokenizer, path, extra_skip = _resolve_variant(size)
    bit_quant_skip += extra_skip

    model = model_cls.from_pretrained(path, **_build_load_kwargs(bits, bit_quant_skip))

    # Parameter arrangement (freeze_model presumably disables gradients).
    freeze_model(model)
    model.eval()

    # Compare dtypes directly. Substring-matching str(dtype) also matched
    # "bfloat16" (it contains "float16") and cast it to itself for nothing.
    for param in model.parameters():
        if param.dtype in (torch.float32, torch.float16):
            param.data = param.data.to(torch.bfloat16)
    return model, tokenizer