v1
- app.py +2 -2
- trol/arch_internlm2/modeling_internlm2.py +5 -3
- trol/arch_phi3/modeling_phi3.py +5 -3
- trol/load_trol.py +7 -1
app.py CHANGED
@@ -18,8 +18,8 @@ from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
 # flash attention
-
-
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 # accel
 accel = Accelerator()
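The app.py change installs flash-attn at runtime through a pip subprocess rather than declaring it as a build-time dependency (the CUDA build is skipped via FLASH_ATTENTION_SKIP_CUDA_BUILD). A minimal sketch of the same pattern with an import guard on top, so the install only runs when the package is missing (the guard is an assumption, not part of this commit):

import importlib.util
import subprocess

# Install flash-attn only when it is not already importable in the current environment.
if importlib.util.find_spec("flash_attn") is None:
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True,
    )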
trol/arch_internlm2/modeling_internlm2.py CHANGED
@@ -867,13 +867,15 @@ class InternLM2Model(InternLM2PreTrainedModel):
         self.norm = InternLM2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
 
-        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
-        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
 
+    def initialize_trol_gating(self):
+        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1).cuda()]*self.config.num_hidden_layers)
+        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
+
+
     def get_input_embeddings(self):
         return self.tok_embeddings
 
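In modeling_internlm2.py the TroL gating layers move out of __init__ into an explicit initialize_trol_gating() method: the backbone can now be instantiated and its pretrained weights loaded without the gates, which are created directly on CUDA only when requested and then filled from the trol_gating.pt checkpoint. A minimal sketch of the intended call order, assuming an already-constructed InternLM2Model instance named model and a local copy of the checkpoint:

import torch

# Create the gating layers on GPU (they are no longer built in __init__),
# then load their trained weights.
model.initialize_trol_gating()
model.trol_gating.load_state_dict(torch.load("trol_gating.pt"))

# Each gate output is squashed into (0, 1):
# trol_function(x, idx) = 0.5 * tanh(trol_gating[idx](x)) + 0.5

The same change is mirrored in the Phi-3 modeling file below.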
trol/arch_phi3/modeling_phi3.py CHANGED
@@ -1031,13 +1031,15 @@ class Phi3Model(Phi3PreTrainedModel):
         self._attn_implementation = "flash_attention_2"
         self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1)]*self.config.num_hidden_layers)
-        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
 
+    def initialize_trol_gating(self):
+        self.trol_gating = nn.ModuleList([nn.Linear(self.config.hidden_size, 1).cuda()]*self.config.num_hidden_layers)
+        self.trol_function = lambda x, idx: 0.5*F.tanh(self.trol_gating[idx](x))+0.5
+
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
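One detail shared by the removed and the added gating lines in both files: nn.ModuleList([nn.Linear(...)]*self.config.num_hidden_layers) repeats the same Linear object, so every index in trol_gating refers to one shared gate rather than num_hidden_layers independent ones. If that sharing is intentional the code is fine as is; if independent per-layer gates were intended, a list comprehension would create separate modules (a sketch under that assumption, not part of this commit):

import torch.nn as nn

hidden_size, num_hidden_layers = 3072, 32  # example values, not taken from the configs

# [nn.Linear(...)] * N aliases a single module; a comprehension builds N independent gates.
trol_gating = nn.ModuleList(
    [nn.Linear(hidden_size, 1) for _ in range(num_hidden_layers)]
)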
trol/load_trol.py CHANGED
@@ -81,11 +81,17 @@ def load_trol(link):
     # setting config
     setting_trol_config(trol, tok_trol, image_special_token)
 
-
     # trol gating load
     from huggingface_hub import hf_hub_download
     try:
+        trol.model.initialize_trol_gating()
         trol.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
     except:
+        trol.language_model.model.initialize_trol_gating()
         trol.language_model.model.trol_gating.load_state_dict(torch.load(hf_hub_download(repo_id=path, filename="trol_gating.pt")))
+
+    # X -> float16 conversion
+    for param in trol.parameters():
+        if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
+            param.data = param.data.to(torch.float16)
     return trol, tok_trol
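Taken together, load_trol() now initializes the gating layers on whichever submodule owns them, loads trol_gating.pt from the Hub into them, and finally casts the remaining floating-point parameters to float16. A hedged usage sketch (the link value is a placeholder, not taken from this commit):

from trol.load_trol import load_trol

# 'TroL-7B' is a hypothetical identifier; pass whichever link load_trol expects.
trol, tok_trol = load_trol(link='TroL-7B')

# After loading, float parameters have been cast to float16 by the conversion loop above.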