# DiffuSynthV0.2: model/multimodal_model.py
import itertools
import json
import random
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tools import create_key
from model.timbre_encoder_pretrain import get_timbre_encoder
class ProjectionLayer(nn.Module):
"""Single-layer Linear projection with dropout, layer norm, and Gelu activation"""
def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
self.projection = nn.Linear(input_dim, output_dim)
self.gelu = nn.GELU()
self.fc = nn.Linear(output_dim, output_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(output_dim)
def forward(self, x):
projected = self.projection(x)
x = self.gelu(projected)
x = self.fc(x)
x = self.dropout(x)
x = x + projected
x = self.layer_norm(x)
return x
class ProjectionHead(nn.Module):
"""Stack of 'ProjectionLayer'"""
def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super().__init__()
self.layers = nn.ModuleList([ProjectionLayer(embedding_dim if i == 0 else projection_dim,
projection_dim,
dropout) for i in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
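
# A minimal usage sketch (dimensions are illustrative, not taken from this
# repo): project 512-dim encoder features into a 256-dim shared space.
#
#     head = ProjectionHead(embedding_dim=512, projection_dim=256,
#                           dropout=0.1, num_layers=2)
#     features = torch.randn(8, 512)   # batch of 8 feature vectors
#     embeddings = head(features)      # -> shape (8, 256)
#
# The residual connection in each ProjectionLayer keeps the raw linear
# projection in the output, so stacked layers refine it rather than replace it.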
class multi_modal_model(nn.Module):
"""The multi-modal model for contrastive learning"""
def __init__(
self,
timbre_encoder,
text_encoder,
spectrogram_feature_dim,
text_feature_dim,
multi_modal_emb_dim,
temperature,
dropout,
num_projection_layers=1,
freeze_spectrogram_encoder=True,
freeze_text_encoder=True,
):
super().__init__()
self.timbre_encoder = timbre_encoder
self.text_encoder = text_encoder
self.multi_modal_emb_dim = multi_modal_emb_dim
self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
num_layers=num_projection_layers)
self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
projection_dim=self.multi_modal_emb_dim, dropout=dropout,
num_layers=num_projection_layers)
self.temperature = temperature
        # Freeze the spectrogram (timbre) encoder unless it is being fine-tuned
        for param in self.timbre_encoder.parameters():
            param.requires_grad = not freeze_spectrogram_encoder
        # Freeze the text encoder unless it is being fine-tuned
        for param in self.text_encoder.parameters():
            param.requires_grad = not freeze_text_encoder
def forward(self, spectrogram_batch, tokenized_text_batch):
        # Encode both modalities into feature vectors
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        text_features = self.text_encoder.get_text_features(**tokenized_text_batch)
        # Project each modality into the shared embedding space (same dimension)
        spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
        text_embeddings = self.text_projection(text_features)
        # Contrastive loss: temperature-scaled cross-modal similarity logits
        logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
        # Soft targets from within-modality similarities, scaled consistently
        # with the logits via "/ (2 * temperature)" (the original
        # "/ 2 * self.temperature" multiplied by the temperature instead,
        # because of operator precedence)
        spectrograms_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (spectrograms_similarity + texts_similarity) / (2 * self.temperature), dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        spectrograms_loss = cross_entropy(logits.T, targets.T, reduction='none')
        contrastive_loss = (spectrograms_loss + texts_loss) / 2.0  # shape: (batch_size,)
        contrastive_loss = contrastive_loss.mean()
return contrastive_loss
    def get_text_features(self, input_ids, attention_mask):
        """Encode tokenized text and project it into the shared embedding space."""
        text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(text_features)
    def get_timbre_features(self, spectrogram_batch):
        """Encode a spectrogram batch and project it into the shared embedding space."""
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        return self.spectrogram_projection(spectrogram_features)
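
# A hedged retrieval sketch: once trained, embeddings from the two modalities
# are directly comparable. `tokenizer` and `spectrograms` below are
# placeholders, not objects defined in this file.
#
#     model.eval()
#     with torch.no_grad():
#         tokens = tokenizer(["a bright pluck", "a warm pad"],
#                            padding=True, return_tensors="pt")
#         text_emb = model.get_text_features(**tokens)           # (2, emb_dim)
#         timbre_emb = model.get_timbre_features(spectrograms)   # (N, emb_dim)
#         scores = text_emb @ timbre_emb.T   # text-to-spectrogram similarity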
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy against soft (probability-distribution) targets."""
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction}")
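
# Worked example (values illustrative): with a uniform soft target over two
# classes, the loss is the negated expectation of the log-probabilities.
#
#     preds = torch.tensor([[2.0, 0.0]])      # logits
#     targets = torch.tensor([[0.5, 0.5]])    # soft target distribution
#     cross_entropy(preds, targets)           # ≈ tensor([1.1269])
#
# log_softmax(preds) is [-0.1269, -2.1269], and 0.5 * (0.1269 + 2.1269) ≈ 1.1269.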
def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None, device="cpu"):
mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
print(f"Model intialized, size: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
mmm.to(device)
if load_pretrain:
print(f"Loading weights from models/{model_name}_MMM.pth")
checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
mmm.load_state_dict(checkpoint['model_state_dict'])
mmm.eval()
return mmm
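
# A construction sketch. The keys mirror multi_modal_model.__init__; the
# dimensions and encoder objects here are placeholder assumptions.
#
#     model_Config = {
#         "spectrogram_feature_dim": 512,
#         "text_feature_dim": 512,
#         "multi_modal_emb_dim": 256,
#         "temperature": 1.0,
#         "dropout": 0.1,
#         "num_projection_layers": 1,
#         "freeze_spectrogram_encoder": True,
#         "freeze_text_encoder": True,
#     }
#     mmm = get_multi_modal_model(timbre_encoder, text_encoder, model_Config,
#                                 load_pretrain=False, device="cpu")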
def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
    """Run one training step on a batch whose attribute keys are pairwise distinct."""
    # Redraw batches until every sample has a unique key: duplicate captions
    # would make the contrastive targets ambiguous within a batch.
    (data, attributes) = next(iter(train_loader))
    keys = [create_key(attribute) for attribute in attributes]
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(train_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)
    # Each attribute key maps to several candidate captions; sample one per item
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [random.choice(candidates) for candidates in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
    loss = model(data, tokenized_text)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
    """Estimate the contrastive loss on one validation batch (no gradient updates)."""
    (data, attributes) = next(iter(valid_loader))
    keys = [create_key(attribute) for attribute in attributes]
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(valid_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [random.choice(candidates) for candidates in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)
    model.eval()
    with torch.no_grad():
        loss = model(data, tokenized_text)
    return loss.item()
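
# One step of each, sketched under the assumption that the loaders yield
# (spectrogram_batch, attributes) pairs and that `labels_mapping` maps
# create_key(attributes) to a list of caption strings:
#
#     train_loss = train_epoch(tokenizer, mmm, train_loader, labels_mapping,
#                              optimizer, device)
#     val_loss = valid_epoch(tokenizer, mmm, valid_loader, labels_mapping, device)
#
# Both functions draw a single fresh batch per call, so the caller owns the loop.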
def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
timbre_encoder_Config, MMM_config, MMM_training_config,
mmm_name, BATCH_SIZE, max_iter=0, load_pretrain=True,
timbre_encoder_name=None, init_loss=None, save_steps=2000):
def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size, current_iter,
current_loss):
        # Copy so the caller's MMM_config dict is not mutated
        model_hyperparameter = dict(MMM_config)
        model_hyperparameter.update(MMM_training_config)
model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
model_hyperparameter["model_size"] = model_size
model_hyperparameter["current_iter"] = current_iter
model_hyperparameter["current_loss"] = current_loss
with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
device=device)
mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}")
total_parameters = sum(p.numel() for p in mmm.parameters())
trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad)
print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}")
    # The projection heads always train; encoder parameter groups are appended below only when unfrozen
    params = [
{"params": itertools.chain(
mmm.spectrogram_projection.parameters(),
mmm.text_projection.parameters(),
), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]},
]
if not MMM_config["freeze_text_encoder"]:
params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"],
"weight_decay": MMM_training_config["text_encoder_weight_decay"]})
if not MMM_config["freeze_spectrogram_encoder"]:
params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"],
"weight_decay": MMM_training_config["timbre_encoder_weight_decay"]})
optimizer = torch.optim.AdamW(params, weight_decay=0.)
if load_pretrain:
print(f"Loading weights from models/{mmm_name}_MMM.pt")
checkpoint = torch.load(f'models/{mmm_name}_MMM.pth')
mmm.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
print("Model initialized.")
if max_iter == 0:
print("Return model directly.")
return mmm, optimizer
if init_loss is None:
previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device)
else:
previous_lowest_loss = init_loss
print(f"Initial total loss: {previous_lowest_loss}")
train_loss_list = []
for i in range(max_iter):
mmm.train()
train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device)
train_loss_list.append(train_loss)
        # Read the global step counter from the first tracked parameter's optimizer state
        optim_state = optimizer.state_dict()['state']
        step = int(optim_state[next(iter(optim_state))]['step'])
        if (i + 1) % 100 == 0:
            print(f"{step} steps")
        if (i + 1) % save_steps == 0:
            current_loss = np.mean(train_loss_list[-save_steps:])
            print(f"train_total_loss: {current_loss}")
            # Checkpoint only when the running mean loss improves
            if current_loss < previous_lowest_loss:
previous_lowest_loss = current_loss
torch.save({
'model_state_dict': mmm.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, f'models/{mmm_name}_MMM.pth')
save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step,
current_loss)
return mmm, optimizer
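
# An end-to-end sketch of the training entry point. All names below
# (tokenizer, encoders, dataloader, configs) are assumed to come from the
# surrounding project and are placeholders here:
#
#     mmm, optimizer = train_multi_modal_model(
#         device="cuda",
#         training_dataloader=train_loader,
#         labels_mapping=labels_mapping,
#         text_tokenizer=tokenizer,
#         text_encoder=text_encoder,
#         timbre_encoder_Config=timbre_cfg,
#         MMM_config=model_Config,
#         MMM_training_config=training_cfg,
#         mmm_name="demo",
#         BATCH_SIZE=64,
#         max_iter=10000,
#         load_pretrain=False,
#         timbre_encoder_name="timbre_v1",
#     )
#
# MMM_training_config is expected to carry at least "head_lr" and
# "head_weight_decay", plus encoder learning rates when encoders are unfrozen.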