mjbuehler's picture
Upload moe_phi3_v.py
cd35978 verified
raw
history blame contribute delete
No virus
16.1 kB
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import copy
from .modeling_phi3_v import Phi3VForCausalLM, Phi3MLP
from .configuration_phi3_v import Phi3VConfig
from torch.optim import Adam
from typing import Optional, Tuple
from transformers import (
PreTrainedModel,
AutoConfig,
)
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Define the Gating Layer
class GatingLayer(nn.Module):
def __init__(self, input_dim, num_experts, k, layer_dtype=torch.float16):
super(GatingLayer, self).__init__()
self.num_experts = num_experts
self.k = k
self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)
def forward(self, x):
gate_scores = torch.softmax(self.gate(x), dim=-1)
topk_values, topk_indices = torch.topk(gate_scores, self.k, dim=-1)
topk_values = F.softmax(topk_values, dim=-1)
return topk_values, topk_indices
class MoE(nn.Module):
def __init__(self, input_dim, experts, gating_layer, config):
super(MoE, self).__init__()
self.experts = nn.ModuleList(experts)
self.gating_layer = gating_layer
self.output_dim = config.hidden_size
def forward(self, x):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
gate_values, gate_indices = self.gating_layer(x)
batch_size, seq_length, _ = x.size()
# Stack all expert parameters for efficient processing
expert_outputs = []
for expert in self.experts:
up_states = expert.gate_up_proj(x.view(-1, x.size(-1))) # Flatten to [batch_size * seq_length, input_dim]
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * expert.activation_fn(gate)
expert_output = expert.down_proj(up_states)
expert_outputs.append(expert_output.view(batch_size, seq_length, -1))
expert_outputs = torch.stack(expert_outputs, dim=-1) # Shape: [batch_size, seq_length, hidden_size, num_experts]
# Use torch.gather to select the expert outputs based on gate_indices
expanded_gate_indices = gate_indices.unsqueeze(-2).expand(-1, -1, self.output_dim, -1) # Shape: [batch_size, seq_length, hidden_size, k]
selected_expert_outputs = torch.gather(expert_outputs, -1, expanded_gate_indices) # Shape: [batch_size, seq_length, hidden_size, k]
# Weight the selected expert outputs by gate values
gate_values = gate_values.unsqueeze(-2) # Shape: [batch_size, seq_length, 1, k]
weighted_expert_outputs = selected_expert_outputs * gate_values # Shape: [batch_size, seq_length, hidden_size, k]
# Sum the weighted expert outputs across the k dimension
moe_output = weighted_expert_outputs.sum(dim=-1) # Shape: [batch_size, seq_length, hidden_size]
return moe_output.to(self.gating_layer.gate.weight.dtype)
# Define the ModifiedPhi3DecoderLayer Layer
class ModifiedPhi3DecoderLayer(nn.Module):
def __init__(self, original_layer, moe_layer):
super(ModifiedPhi3DecoderLayer, self).__init__()
self.self_attn = original_layer.self_attn
self.mlp = moe_layer
self.input_layernorm = original_layer.input_layernorm
self.resid_attn_dropout = original_layer.resid_attn_dropout
self.resid_mlp_dropout = original_layer.resid_mlp_dropout
self.post_attention_layernorm = original_layer.post_attention_layernorm
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
with torch.autocast(device_type="cuda", dtype=hidden_states.dtype):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_outputs = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
attn_output = attn_outputs[0]
hidden_states = residual + self.resid_attn_dropout(attn_output)
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.resid_mlp_dropout(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
if use_cache:
outputs += (attn_outputs[2],)
return outputs
#Define Phi3VForCausalLMMoEConfig
class Phi3VForCausalLMMoEConfig(Phi3VConfig):
model_type = "phi3_v_moe"
def __init__(self, config=None, k=1, num_expert_models=2, use_embeddings_in_router=False, **kwargs):
if config is not None:
kwargs.update(config.to_dict())
super().__init__(**kwargs)
self.k = k
self.num_expert_models = num_expert_models
self.architectures = "Phi3VForCausalLMMoE"
self.auto_map = {
"AutoConfig": "moe_phi3_v.Phi3VForCausalLMMoEConfig",
"AutoModelForCausalLM": "moe_phi3_v.Phi3VForCausalLMMoE",
}
self.use_embeddings_in_router=use_embeddings_in_router
#Define MoE Model
class Phi3VForCausalLMMoE(Phi3VForCausalLM):
config_class = Phi3VForCausalLMMoEConfig
def __init__(
self,
config,
base_model=None,
expert_models=None,
layer_dtype=torch.bfloat16,
**kwargs,
):
super().__init__(config)
self.layer_dtype = layer_dtype
self.custom_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
k = self.config.k
self.num_layers = len(base_model.model.layers) if base_model else 0
self.config.auto_map = {
"AutoConfig": "moe_phi3_v.Phi3VForCausalLMMoEConfig",
"AutoModelForCausalLM": "moe_phi3_v.Phi3VForCausalLMMoE",
}
self.use_embeddings_in_router=config.use_embeddings_in_router
print ("Use embeddigs in router: ", self.use_embeddings_in_router )
self.model = base_model or Phi3VForCausalLM(
self.config
)
if base_model and expert_models:
self.num_expert_models = len(expert_models)
self._init_moe_layers(base_model, expert_models, k, layer_dtype)
else:
print(
"Init function called and generating dummy experts: k=",
k,
"experts=",
self.config.num_expert_models,
)
num_dummy_experts = self.config.num_expert_models
self._init_moe_layers_with_dummy_experts(
self.model, k, num_dummy_experts, layer_dtype
)
self.config.model_type = "phi3_v_moe"
def _init_base_model(self):
return PreTrainedModel(self.config)
def _init_moe_layers(self, base_model, expert_models, k, layer_dtype):
self.num_layers = len(base_model.model.layers)
for i in tqdm(range(self.num_layers)):
experts = []
for expert_model in expert_models:
expert = copy.deepcopy(expert_model.model.layers[i].mlp).to(
dtype=layer_dtype
)
experts.append(expert)
gating_layer = GatingLayer(
input_dim=self.config.hidden_size,
num_experts=len(experts),
k=k,
layer_dtype=layer_dtype,
)
moe_layer = MoE(
input_dim=self.config.hidden_size,
experts=experts,
gating_layer=gating_layer,
config=self.config,
).to(dtype=layer_dtype)
self.model.model.layers[i] = ModifiedPhi3DecoderLayer(
self.model.model.layers[i], moe_layer
).to(dtype=layer_dtype)
def _init_moe_layers_with_dummy_experts(
self, base_model, k, num_dummy_experts, layer_dtype
):
self.num_layers = len(base_model.model.layers)
for i in tqdm(range(self.num_layers)):
experts = []
for _ in range(num_dummy_experts):
dummy_expert = Phi3MLP(self.config).to(dtype=layer_dtype)
experts.append(dummy_expert)
gating_layer = GatingLayer(
input_dim=self.config.hidden_size,
num_experts=len(experts),
k=k,
layer_dtype=layer_dtype,
)
moe_layer = MoE(
input_dim=self.config.hidden_size,
experts=experts,
gating_layer=gating_layer,
config=self.config,
).to(dtype=layer_dtype)
self.model.model.layers[i] = ModifiedPhi3DecoderLayer(
self.model.model.layers[i], moe_layer
).to(dtype=layer_dtype)
def forward(self, *args, **kwargs):
return self.model.forward(*args, **kwargs)
def generate(self, *args, **kwargs):
return self.model.generate(*args, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Initialize the model using the superclass method
model = super(Phi3VForCausalLMMoE, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs
)
return model
def plot_loss_histories(self, all_loss_histories, loss_steps, filename="loss_history.svg"):
plt.figure(figsize=(12, 8))
for layer_idx, loss_history in enumerate(all_loss_histories):
plt.plot(
range(0, len(loss_history) * loss_steps, loss_steps),
loss_history,
label=f'Layer {layer_idx}',
linewidth=2, # Thicker line
marker='o' # Circle marker for each data point
)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss History per Layer, MoE Gating Network')
plt.legend()
plt.grid(True)
try:
plt.savefig(filename)
except:
print("Figure file save failed...")
plt.show()
def train_gating_layer_params_from_hidden_states(self, processor, prompts_per_expert, epochs=1000, loss_steps=100,
lr=1e-4, layer_offset=0):
self.to(self.custom_device)
self.eval()
print ('btype:', self.layer_dtype, 'device=', self.custom_device)
all_gating_layer_params = []
all_loss_histories = [] # To store loss histories for each layer
expert_hidden_states_per_layer = [[] for _ in range(self.num_layers)]
# Collect hidden states for each expert
for prompts in tqdm(prompts_per_expert, desc="Processing Prompts"):
for prompt in tqdm(prompts, desc="Processing Single Prompt", leave=False):
inputs = processor(text=prompt['text'], images=prompt['image'], return_tensors="pt").to(self.custom_device).to(self.layer_dtype)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states
for layer_idx in tqdm(range(self.num_layers)):
hidden_state = hidden_states[layer_idx+layer_offset].mean(dim=1) # Averaging over the sequence dimension
expert_hidden_states_per_layer[layer_idx].append(hidden_state)
# Train the gating layers
for layer_idx in tqdm(range(self.num_layers), desc="Training Gating Layers"):
print(f"Training gating layer parameters for layer {layer_idx}")
# Ensure we have hidden states collected for the current layer
if not expert_hidden_states_per_layer[layer_idx]:
raise ValueError(f"No hidden states collected for layer {layer_idx}")
# Aggregate hidden states for each expert and stack them
expert_hidden_states = []
num_prompts_per_expert = len(prompts_per_expert[0])
for i in range(len(prompts_per_expert)):
hidden_states_for_expert = expert_hidden_states_per_layer[layer_idx][i * num_prompts_per_expert: (i + 1) * num_prompts_per_expert]
hidden_state_avg = torch.stack(hidden_states_for_expert).mean(dim=0)
expert_hidden_states.append(hidden_state_avg)
expert_hidden_states = torch.stack(expert_hidden_states).to(self.layer_dtype)
input_dim = self.config.hidden_size
num_experts = self.config.num_expert_models
class SimpleGatingLayer(nn.Module):
def __init__(self, input_dim, num_experts, layer_dtype=torch.bfloat16):
super(SimpleGatingLayer, self).__init__()
self.gate = nn.Linear(input_dim, num_experts).to(dtype=layer_dtype)
def forward(self, x):
#return torch.softmax(self.gate(x), dim=-1)
return self.gate(x)
gating_layer = SimpleGatingLayer(self.config.hidden_size, num_experts, layer_dtype=self.layer_dtype).to(self.custom_device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(gating_layer.parameters(), lr=lr)
loss_history = []
for epoch in tqdm(range(epochs), desc=f"Training Gating Layer {layer_idx}"):
optimizer.zero_grad()
# Reshape expert_hidden_states to match (batch_size, input_dim)
expert_hidden_states_reshaped = expert_hidden_states.view(-1, input_dim)
outputs = gating_layer(expert_hidden_states_reshaped)
labels = torch.arange(num_experts).to(self.custom_device)
#print ("outputs, labels" , outputs.shape, labels.shape)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if epoch % loss_steps == 0:
loss_history.append(loss.item())
all_loss_histories.append(loss_history)
all_gating_layer_params.append(gating_layer.state_dict())
self.plot_loss_histories(all_loss_histories, loss_steps)
return all_gating_layer_params
def set_gating_layer_params(self, gating_layer_params):
for layer_idx, params in enumerate(gating_layer_params):
self.model.model.layers[layer_idx].mlp.gating_layer.load_state_dict(params)
def freeze_except_gating_layers(model):
# freeze_except_gating_layers(moe_model)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
# Unfreeze gating layer parameters
for layer in model.model.model.layers:
for name, param in layer.mlp.gating_layer.named_parameters():
param.requires_grad = True
def un_freeze_all(model):
# freeze_except_gating_layers(moe_model)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = True
from transformers import AutoConfig
AutoConfig.register("phi3_v_moe", Phi3VForCausalLMMoEConfig)
from transformers.models.auto.modeling_auto import MODEL_MAPPING
MODEL_MAPPING.update({"phi3_v_moe": Phi3VForCausalLMMoE})