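"""Mixture-of-Experts (MoE) model for image classification.

Two frozen expert classifiers are mixed by a frozen switch gate; a small
trainable subgate MLP then blends that mixture with a frozen baseline
classifier. Only the subgate's parameters receive gradients.
"""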
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModelForImageClassification
from .configuration_moe import MoEConfig


def subgate(num_out):
    """Small trainable MLP gate over flattened pixels.

    Assumes 224x224 RGB inputs; returns unnormalized scores with
    `num_out` entries per image.
    """
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(224 * 224 * 3, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, num_out),
    )
class MoEModelForImageClassification(PreTrainedModel):
    config_class = MoEConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes

        # Frozen pretrained components: a 2-way switch gate, a baseline
        # classifier, and two expert classifiers.
        self.switch_gate_model = AutoModelForImageClassification.from_pretrained(
            config.switch_gate
        )
        self.baseline_model = AutoModelForImageClassification.from_pretrained(
            config.baseline_model
        )
        self.expert_model_1 = AutoModelForImageClassification.from_pretrained(
            config.experts[0]
        )
        self.expert_model_2 = AutoModelForImageClassification.from_pretrained(
            config.experts[1]
        )

        # The subgate MLP is the only trainable module; it arbitrates between
        # the experts' mixture and the baseline prediction.
        self.subgate = subgate(2)

        # Freeze all pretrained parameters.
        for module in [
            self.switch_gate_model,
            self.baseline_model,
            self.expert_model_1,
            self.expert_model_2,
        ]:
            for param in module.parameters():
                param.requires_grad = False
    def forward(self, pixel_values, labels=None):
        # Switch gate: one score per expert, shape (batch, 2), used directly
        # (without softmax) as mixture weights over the expert outputs.
        switch_gate_result = self.switch_gate_model(pixel_values).logits

        # Per-expert class logits, shape (batch, num_classes).
        expert1_result = self.expert_model_1(pixel_values).logits
        expert2_result = self.expert_model_2(pixel_values).logits

        # Gating network: weighted sum of the stacked expert logits.
        experts_result = torch.stack(
            [expert1_result, expert2_result], dim=1
        ) * switch_gate_result.unsqueeze(-1)
        experts_result = experts_result.sum(dim=1)

        # Subgate: softmax over two scores blends the experts' mixture with
        # the frozen baseline model's prediction.
        baseline_model_result = self.baseline_model(pixel_values).logits
        subgate_result = self.subgate(pixel_values)
        subgate_prob = F.softmax(subgate_result, dim=-1)
        experts_and_base_result = torch.stack(
            [experts_result, baseline_model_result], dim=1
        ) * subgate_prob.unsqueeze(-1)
        logits = experts_and_base_result.sum(dim=1)

        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}
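
# --- Usage sketch (illustrative, not from the original source) ---
# Assumes MoEConfig exposes the `switch_gate`, `baseline_model`, `experts`,
# and `num_classes` fields read in __init__ above; the checkpoint names are
# placeholders. The switch gate checkpoint must output 2 logits (one per
# expert), and inputs must be 224x224 RGB to match the subgate MLP.
#
#   config = MoEConfig(
#       switch_gate="org/gate-checkpoint",          # placeholder name
#       baseline_model="org/baseline-checkpoint",   # placeholder name
#       experts=["org/expert-1", "org/expert-2"],   # placeholder names
#       num_classes=10,
#   )
#   model = MoEModelForImageClassification(config)
#   pixel_values = torch.randn(1, 3, 224, 224)
#   outputs = model(pixel_values)                   # {"logits": ...}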