from typing import List

import torch
import torch.nn as nn
from omegaconf import DictConfig

from barista.data.metadata import Metadata
from barista.models.tokenizer import Tokenizer
from barista.models.transformer import Transformer


class Barista(nn.Module):
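    """Pairs a signal Tokenizer with a Transformer backbone; a task-specific
    head can be attached afterwards via `create_downstream_head`."""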
    def __init__(self, model_config: DictConfig, metadata: Metadata, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metadata = metadata

        # Turns raw multichannel signals into token sequences.
        self.tokenizer = Tokenizer(
            config=model_config.tokenizer,
            metadata=self.metadata,
        )

        # Transformer encoder over the tokenized sequences.
        self.backbone = Transformer(
            **model_config.backbone,
        )

        self.d_hidden = model_config.backbone.d_hidden

        # Placeholder; a task head is attached later via `create_downstream_head`.
        self.head = None

    def create_downstream_head(self, n_chans: int, output_dim: int):
        # Learned weights that pool the per-channel, per-subsegment latents
        # into a single embedding per example.
        self.channel_weights = nn.Linear(
            n_chans * self.tokenizer.num_subsegments,
            1,
            bias=False,
        )
        # Maps the pooled embedding to task logits.
        self.binary_classifier = nn.Linear(self.d_hidden, output_dim)
    def get_latent_embeddings(self, x: List[torch.Tensor], subject_sessions: List):
        # `x` is a list of raw signal tensors, each with the batch dimension
        # first (`forward` relies on this when it reads `x[0].shape[0]`).
        # `output_as_list=False` packs all sessions into a single batch
        # instead of returning one entry per session.
        tokenized_x = self.tokenizer(x, subject_sessions, output_as_list=False)

        latents = self.backbone(
            x=tokenized_x.tokens,
            seq_lens=tokenized_x.seq_lens,
            position_ids=tokenized_x.position_ids,
        )

        return latents

    def forward(self, x: List[torch.Tensor], subject_sessions: List):
        latents = self.get_latent_embeddings(x, subject_sessions)

        # Regroup the latents per example into shape
        # (batch, n_chans * num_subsegments, d_hidden).
        batch_size = x[0].shape[0]
        latents_reshaped = latents.reshape(batch_size, -1, latents.shape[-1])
        # (B, N, D) -> (B, D, N): the pooling linear maps the token axis N
        # down to 1, and the squeeze leaves a (B, D) embedding.
        x = self.channel_weights(latents_reshaped.permute(0, 2, 1)).squeeze(dim=-1)
        x = self.binary_classifier(x)

        return x

    def get_task_params(self):
        # Parameters of the downstream head (pooling weights + classifier).
        return [
            *self.channel_weights.named_parameters(),
            *self.binary_classifier.named_parameters(),
        ]

    def get_upstream_params(self):
        # Parameters of the pretrained tokenizer and backbone.
        return [*self.tokenizer.named_parameters(), *self.backbone.named_parameters()]
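
# A minimal fine-tuning sketch, assuming a Hydra config `cfg` and a `metadata`
# object built elsewhere in the project; `n_chans=64`, `output_dim=2`, and the
# learning rates below are illustrative placeholders, not values defined here.
# The two parameter groups let the freshly created head train faster than the
# pretrained tokenizer/backbone:
#
#   model = Barista(cfg.model, metadata)
#   model.create_downstream_head(n_chans=64, output_dim=2)
#   optimizer = torch.optim.AdamW([
#       {"params": [p for _, p in model.get_task_params()], "lr": 1e-3},
#       {"params": [p for _, p in model.get_upstream_params()], "lr": 1e-5},
#   ])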