import torch import torch.nn as nn import torch.nn.functional as F # Definição de uma camada de embedding com atenção esparsa para texto class SparseTextEmbedding(nn.Module): """ Camada de embedding para texto com atenção multi-cabeça. Realiza embeddings de tokens de texto e aplica atenção multi-cabeça. """ def __init__(self, num_tokens, emb_dim): super().__init__() self.embedding = nn.Embedding(num_tokens, emb_dim) self.attention = nn.MultiheadAttention(emb_dim, num_heads=8, batch_first=True) def forward(self, x): x = self.embedding(x) x, _ = self.attention(x, x, x) return x # Processador genérico para transformar entradas numéricas em embeddings class GenericProcessor(nn.Module): """ Processador genérico que transforma entradas numéricas em embeddings. Utiliza uma camada linear seguida por uma ativação ReLU. """ def __init__(self, input_dim, emb_dim): super().__init__() self.fc = nn.Linear(input_dim, emb_dim) def forward(self, x): return F.relu(self.fc(x)) # Especialista em transformador para domínios específicos class TransformerExpert(nn.Module): """ Especialista em domínio específico usando um encoder Transformer. Projetado para processar embeddings e realizar tarefas específicas de domínio. """ def __init__(self, emb_dim, num_heads, num_layers, ff_dim): super().__init__() transformer_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim) self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers) def forward(self, x): return self.transformer_encoder(x) # Decodificador Transformer com atenção cruzada class TransformerDecoderWithCrossAttention(nn.Module): """ Decodificador Transformer com atenção cruzada. Combina informações de múltiplas fontes e projeta o resultado final. """ def __init__(self, emb_dim, num_heads, num_layers, ff_dim): super().__init__() transformer_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=ff_dim) self.transformer_decoder = nn.TransformerDecoder(transformer_layer, num_layers=num_layers) self.projection = nn.Linear(emb_dim, emb_dim) def forward(self, x, memory): output = self.transformer_decoder(x, memory) return self.projection(output) # Modelo principal que incorpora os componentes acima class EnedinaModel(nn.Module): """ Modelo principal: Enedina. Integra diferentes componentes especializados para processar múltiplos tipos de entrada. """ def __init__(self, text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim, emb_dim=1024, num_heads=16, num_layers=12, ff_dim=4096): super().__init__() self.text_embedding = SparseTextEmbedding(text_num_tokens, emb_dim) self.image_processor = GenericProcessor(image_input_dim, emb_dim) self.equation_processor = GenericProcessor(equation_input_dim, emb_dim) self.diagram_processor = GenericProcessor(diagram_input_dim, emb_dim) self.experts = nn.ModuleList([ TransformerExpert(emb_dim, num_heads, num_layers, ff_dim) for _ in range(4) ]) self.gate = nn.Linear(emb_dim * 4, 4) self.transformer_decoder = TransformerDecoderWithCrossAttention(emb_dim, num_heads, num_layers, ff_dim) def forward(self, text_input, image_input, equation_input, diagram_input): text_emb = self.text_embedding(text_input) image_emb = self.image_processor(image_input).unsqueeze(1) equation_emb = self.equation_processor(equation_input).unsqueeze(1) diagram_emb = self.diagram_processor(diagram_input).unsqueeze(1) # Estrutura dos especialistas expert_inputs = [equation_emb, image_emb, diagram_emb, text_emb] expert_outputs = [] for i, expert in enumerate(self.experts): expert_output = expert(expert_inputs[i].permute(1, 0, 2)) expert_outputs.append(expert_output.permute(1, 0, 2)[:, -1, :]) # Combina as saídas dos especialistas combined_expert_outputs = torch.cat(expert_outputs, dim=-1) # Calcula os pesos do gate e aplica a combinação ponderada das saídas dos especialistas gate_weights = F.softmax(self.gate(combined_expert_outputs), dim=-1) expert_outputs_stack = torch.stack(expert_outputs, dim=1) combined_output = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs_stack, dim=1) # Ajustes de dimensão antes do TransformerDecoder combined_output = combined_output.unsqueeze(0) text_emb = text_emb.permute(1, 0, 2) # Aplica o decodificador Transformer com atenção cruzada output = self.transformer_decoder(text_emb, combined_output) return output # Configuração dos parâmetros do modelo e simulação de entrada para testes text_num_tokens = 200000 image_input_dim = 2048 equation_input_dim = 1024 diagram_input_dim = 1024 batch_size = 4 text_seq_len = 1000 image_seq_len = 10 equation_seq_len = 5 diagram_seq_len = 5 # Inicializa o modelo model = EnedinaModel(text_num_tokens, image_input_dim, equation_input_dim, diagram_input_dim) # Gera entradas simuladas text_input = torch.randint(0, text_num_tokens, (batch_size, text_seq_len)) image_input = torch.randn(batch_size, image_input_dim) equation_input = torch.randn(batch_size, equation_input_dim) diagram_input = torch.randn(batch_size, diagram_input_dim) # Executa o modelo com as entradas simuladas output = model(text_input, image_input, equation_input, diagram_input) # Verifica a forma da saída print("A forma de saída do tensor é:", output.shape)