import itertools
import json
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from tools import create_key
from model.timbre_encoder_pretrain import get_timbre_encoder


class ProjectionLayer(nn.Module):
    """Projection block: linear projection, GELU, a second linear layer, dropout,
    a residual connection back to the projection, and layer normalization."""

    def __init__(self, input_dim, output_dim, dropout):
        super(ProjectionLayer, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the second linear layer
        x = self.layer_norm(x)
        return x


class ProjectionHead(nn.Module):
    """Stack of `ProjectionLayer` blocks; only the first block changes the dimensionality."""

    def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super(ProjectionHead, self).__init__()
        self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim,
                                                     projection_dim,
                                                     dropout) for i in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
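
# Illustrative sketch (not part of the training pipeline): a ProjectionHead maps encoder
# features of arbitrary dimension into the shared multi-modal embedding space. The
# dimensions below are made up for the example.
#
#   head = ProjectionHead(embedding_dim=768, projection_dim=512, dropout=0.1)
#   features = torch.randn(8, 768)   # batch of 8 encoder feature vectors
#   embeddings = head(features)      # -> shape (8, 512)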


class multi_modal_model(nn.Module):
    """Multi-modal (spectrogram/text) model trained with a CLIP-style contrastive objective."""

    def __init__(
            self,
            timbre_encoder,
            text_encoder,
            spectrogram_feature_dim,
            text_feature_dim,
            multi_modal_emb_dim,
            temperature,
            dropout,
            num_projection_layers=1,
            freeze_spectrogram_encoder=True,
            freeze_text_encoder=True,
    ):
        super().__init__()
        self.timbre_encoder = timbre_encoder
        self.text_encoder = text_encoder
        self.multi_modal_emb_dim = multi_modal_emb_dim
        self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
                                              projection_dim=self.multi_modal_emb_dim, dropout=dropout,
                                              num_layers=num_projection_layers)
        self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
                                                     projection_dim=self.multi_modal_emb_dim, dropout=dropout,
                                                     num_layers=num_projection_layers)
        self.temperature = temperature
        # Optionally freeze the timbre (spectrogram) encoder
        for param in self.timbre_encoder.parameters():
            param.requires_grad = not freeze_spectrogram_encoder
        # Optionally freeze the text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = not freeze_text_encoder

    def forward(self, spectrogram_batch, tokenized_text_batch):
        # Get spectrogram and text features and project both into the shared embedding space
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        text_features = self.text_encoder.get_text_features(**tokenized_text_batch)
        spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
        text_embeddings = self.text_projection(text_features)
        # Symmetric contrastive loss with soft targets derived from the in-modality similarities
        logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
        spectrograms_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (spectrograms_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        spectrograms_loss = cross_entropy(logits.T, targets.T, reduction='none')
        contrastive_loss = (spectrograms_loss + texts_loss) / 2.0  # shape: (batch_size,)
        contrastive_loss = contrastive_loss.mean()
        return contrastive_loss

    def get_text_features(self, input_ids, attention_mask):
        """Return text embeddings projected into the multi-modal space."""
        text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(text_features)

    def get_timbre_features(self, spectrogram_batch):
        """Return spectrogram (timbre) embeddings projected into the multi-modal space."""
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        return self.spectrogram_projection(spectrogram_features)
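
# Illustrative sketch (not part of the module): once trained, the two `get_*_features`
# methods embed both modalities into the shared space, where they can be compared with a
# cosine similarity. The tokenizer output and spectrogram tensor below are placeholders.
#
#   tokens = text_tokenizer(["a bright plucked string"], padding=True, return_tensors="pt")
#   text_emb = mmm.get_text_features(tokens["input_ids"], tokens["attention_mask"])
#   timbre_emb = mmm.get_timbre_features(spectrogram_batch)
#   similarity = F.cosine_similarity(text_emb, timbre_emb)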


def cross_entropy(preds, targets, reduction='none'):
    """Cross entropy against soft (probability) targets; returns one value per row."""
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
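
# Note (assumption, for recent PyTorch versions): with probability targets this helper
# matches the built-in soft-target cross entropy, so the two results below should agree
# up to numerical precision:
#
#   logits = torch.randn(4, 4)
#   targets = F.softmax(torch.randn(4, 4), dim=-1)
#   manual = cross_entropy(logits, targets, reduction='mean')
#   builtin = F.cross_entropy(logits, targets)   # accepts soft targets in torch >= 1.10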


def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
    mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
    print(f"Model initialized, trainable parameters: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
    mmm.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_MMM.pth")
        checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
        mmm.load_state_dict(checkpoint['model_state_dict'])
    mmm.eval()
    return mmm
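
# Illustrative sketch (assumed config objects and names, not part of the pipeline): build
# the pretrained timbre encoder first, then wrap both encoders in the multi-modal model
# and load its saved weights.
#
#   timbre_enc = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True,
#                                   model_name="timbre_encoder", device="cpu")
#   mmm = get_multi_modal_model(timbre_enc, text_encoder, MMM_config,
#                               load_pretrain=True, model_name="mmm", device="cpu")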


def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
    # Draw a single batch and re-sample until every attribute key in the batch is unique,
    # so no two items in the contrastive batch share the same label
    (data, attributes) = next(iter(train_loader))
    keys = [create_key(attribute) for attribute in attributes]
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(train_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)
    # Pick one text description at random for each item in the batch
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [random.choice(l) for l in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
    loss = model(data, tokenized_text)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
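
# Illustrative sketch (assumed tokenizer, dataloader, and mapping objects): one gradient
# step on a single re-sampled batch.
#
#   optimizer = torch.optim.AdamW(mmm.parameters(), lr=1e-4)
#   step_loss = train_epoch(text_tokenizer, mmm, training_dataloader,
#                           labels_mapping, optimizer, device="cpu")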


def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
    # Same batch sampling as train_epoch, but without a gradient step
    (data, attributes) = next(iter(valid_loader))
    keys = [create_key(attribute) for attribute in attributes]
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(valid_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [random.choice(l) for l in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        loss = model(data, tokenized_text)
    return loss.item()


def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
                            timbre_encoder_Config, MMM_config, MMM_training_config,
                            mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
                            timbre_encoder_name=None, init_loss=None, save_steps=2000):
    def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
                                  current_loss):
        # Copy the config so the caller's MMM_config dict is not mutated
        model_hyperparameter = dict(MMM_config)
        model_hyperparameter.update(MMM_training_config)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
    timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
                                       device=device)
    mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
    print(f"spectrogram_encoder parameters: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
    print(f"text_encoder parameters: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
    print(f"spectrogram_projection parameters: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
    print(f"text_projection parameters: {sum(p.numel() for p in mmm.text_projection.parameters())}")
    total_parameters = sum(p.numel() for p in mmm.parameters())
    trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
    print(f"Trainable/Total parameters: {trainable_parameters}/{total_parameters}")
    # The projection heads always get their own learning rate; the encoders are only
    # added to the optimizer when they are not frozen
    params = [
        {"params": itertools.chain(
            mmm.spectrogram_projection.parameters(),
            mmm.text_projection.parameters(),
        ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
    ]
    if not MMM_config["freeze_text_encoder"]:
        params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
                       "weight_decay": MMM_training_config["text_encoder_weight_decay"]})
    if not MMM_config["freeze_spectrogram_encoder"]:
        params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
                       "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    if load_pretrain:
        print(f"Loading weights from models/{mmm_name}_MMM.pth")
        checkpoint = torch.load(f'models/{mmm_name}_MMM.pth', map_location=device)
        mmm.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Return model directly.")
        return mmm, optimizer
    if init_loss is None:
        previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
    else:
        previous_lowest_loss = init_loss
    print(f"Initial total loss: {previous_lowest_loss}")
    train_loss_list = []
    for i in range(max_iter):
        mmm.train()
        train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
        train_loss_list.append(train_loss)
        # Read the global optimizer step count from the state of the first parameter
        step = int(next(iter(optimizer.state_dict()['state'].values()))['step'])
        if (i + 1) % 100 == 0:
            print('%d step' % step)
        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss_list[-save_steps:])
            print(f"train_total_loss: {current_loss}")
            # Only keep the checkpoint when the running average loss improves
            if current_loss < previous_lowest_loss:
                previous_lowest_loss = current_loss
                torch.save({
                    'model_state_dict': mmm.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, f'models/{mmm_name}_MMM.pth')
                save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
                                          current_loss)
    return mmm, optimizer
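

# Illustrative end-to-end sketch (all names below are assumed placeholders for the
# project's actual configs, tokenizer, and dataloader):
#
#   mmm, optimizer = train_multi_modal_model(
#       device="cuda", training_dataloader=training_dataloader, labels_mapping=labels_mapping,
#       text_tokenizer=text_tokenizer, text_encoder=text_encoder,
#       timbre_encoder_Config=timbre_encoder_Config, MMM_config=MMM_config,
#       MMM_training_config=MMM_training_config, mmm_name="mmm", BATCH_SIZE=64,
#       max_iter=10000, load_pretrain=False, timbre_encoder_name="timbre_encoder")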