Spaces:
Sleeping
Sleeping
# os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # on a multi-GPU workstation, choose your GPU here (must be placed after `import os` if re-enabled)
import os | |
import torch | |
import random | |
import torch.nn as nn | |
from pathlib import Path | |
from tqdm import tqdm | |
from accelerate import Accelerator | |
from typing import Literal, Callable, Optional, Union | |
from waifuset.utils import log_utils | |
from waifuset.classes import Dataset, ImageInfo | |
from . import mlp, utils, train_utils | |
# Anything accepted as a filesystem path argument: a plain string or a ``pathlib.Path``.
StrPath = Union[str, Path]
class _LossRecorder:
    """Track the latest loss seen at every step index and expose their mean."""

    def __init__(self):
        self.loss_list = []  # most recent loss per step index
        self.loss_total: float = 0.0  # running sum of ``loss_list``

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        """Record ``loss`` for ``step``, replacing the value stored by an earlier epoch."""
        if epoch == 0 or step >= len(self.loss_list):
            self.loss_list.append(loss)
        else:
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

    @property
    def moving_average(self) -> float:
        """Mean of the stored per-step losses (``0.0`` when nothing was recorded yet)."""
        if not self.loss_list:
            return 0.0
        return self.loss_total / len(self.loss_list)


def _run_validation(model, val_loader, accelerator, weight_dtype, criterion_mse, criterion_mae) -> Optional[tuple]:
    """Evaluate ``model`` on ``val_loader``; return ``(mean MSE, mean MAE)`` or ``None`` if no batch was produced."""
    mse_losses = []
    mae_losses = []
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            im_emb_arr = batch['im_emb_arrs'].to(device=accelerator.device, dtype=weight_dtype)
            rating = batch['ratings'].to(device=accelerator.device, dtype=weight_dtype)
            with accelerator.autocast():
                output = model(im_emb_arr)
                mse_losses.append(criterion_mse(output, rating).detach().item())
                mae_losses.append(criterion_mae(output, rating).detach().item())
    if not mse_losses:
        return None
    return sum(mse_losses) / len(mse_losses), sum(mae_losses) / len(mae_losses)


def train(
    dataset_source,
    save_path,
    resume_path: StrPath = None,
    data_preprocessor: Optional[Callable[[ImageInfo], float]] = None,
    rating_func_type: Union[Callable[[ImageInfo], float], Literal['direct', 'label', 'quality']] = 'quality',
    num_train_epochs=50,
    learning_rate=1e-3,
    train_batch_size=256,
    shuffle=True,
    flip_aug=True,
    val_batch_size=512,
    val_every_n_epochs=1,
    val_percentage=0.05,  # 5% of the training data will be used for validation
    save_best_model=True,
    clip_batch_size=1,
    cache_to_disk: bool = False,
    cache_path: StrPath = None,
    mixed_precision=None,
    max_data_loader_n_workers: int = 4,
    persistent_workers=False,
    mlp_model_type: Literal['default', 'large'] = 'default',
    clip_model_name: str = "ViT-L/14",
    input_size: int = 768,
    batch_norm: bool = True,
):
    r"""
    Train the MLP rating predictor on CLIP embeddings and save it to `save_path`.

    The dataset is shuffled, split into train/validation parts, wrapped into data loaders by
    `train_utils.prepare_dataloader`, and optimized with AdamW + MSE loss. The model is saved to
    `save_path` when training finishes, or after confirmation when interrupted with Ctrl+C.

    :param dataset_source: any dataset source, e.g. path to the dataset.
    :param save_path: path to save the trained model.
    :param resume_path: path to the model to resume from.
    :param data_preprocessor: optional callable applied to every image info before training.
    :param rating_func_type: a callable mapping an image info to its rating, or a name resolved by `train_utils.get_rating_func`.
    :param num_train_epochs: number of training epochs.
    :param learning_rate: learning rate.
    :param train_batch_size: training batch size.
    :param shuffle: forwarded to `train_utils.prepare_dataloader` for both loaders.
    :param flip_aug: forwarded to `train_utils.prepare_dataloader` for both loaders.
    :param val_batch_size: validation batch size.
    :param val_every_n_epochs: validation frequency; a falsy value (0 / None) disables validation.
    :param val_percentage: percentage of the training data to be used for validation.
    :param save_best_model: whether to save a checkpoint each time the validation MSE improves.
    :param clip_batch_size: batch size for encoding images.
    :param cache_to_disk: whether to cache the training data to disk.
    :param cache_path: path to the cached training data. If not exists, will be created from `dataset_source`. If exists, will be loaded from disk.
    :param mixed_precision: mixed precision mode passed to `Accelerator` and `train_utils.prepare_dtype`.
    :param max_data_loader_n_workers: maximum number of workers for data loaders.
    :param persistent_workers: whether to use persistent workers for data loaders.
    :param mlp_model_type: which MLP variant `utils.load_model` should build.
    :param clip_model_name: name of the CLIP model to load for encoding images.
    :param input_size: input size of the model.
    :param batch_norm: forwarded to `utils.load_model`.
    """
    log_utils.info("prepare for training")
    accelerator = Accelerator(mixed_precision=mixed_precision)
    weight_dtype = train_utils.prepare_dtype(mixed_precision)
    device = accelerator.device

    # os.cpu_count() may return None; always leave one core free for the main process.
    available_cpus = os.cpu_count() or 1
    max_data_loader_n_workers = min(max_data_loader_n_workers, max(available_cpus - 1, 0))

    if callable(rating_func_type):
        rating_func = rating_func_type
    else:
        rating_func = train_utils.get_rating_func(rating_func_type)

    model2, preprocess = utils.load_clip_models(name=clip_model_name, device=device)  # RN50x64

    dataset = Dataset(dataset_source, verbose=True, condition=lambda img_info: img_info.image_path.is_file())
    if data_preprocessor:
        # NOTE(review): the return value is discarded, so this only has an effect if the
        # preprocessor mutates `img_info` in place — confirm that this is the intended contract.
        for img_key, img_info in dataset.items():
            img_info = data_preprocessor(img_info)

    # Shuffle before splitting so the validation part is a random sample.
    keys = list(dataset.keys())
    random.shuffle(keys)
    dataset = Dataset({k: dataset[k] for k in keys})

    num_pos = 0
    num_neg = 0
    num_mid = 0
    for img_key, img_info in dataset.items():
        rating = rating_func(img_info)
        if rating == 10:
            num_pos += 1
        elif rating == 0:
            num_neg += 1
        else:
            num_mid += 1
    log_utils.info(f"num_pos: {num_pos} | num_mid: {num_mid} | num_neg: {num_neg}")

    train_size = int(len(dataset) * (1 - val_percentage))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = Dataset(dataset.values()[:train_size]), Dataset(dataset.values()[train_size:])
    log_utils.info(f"train_size: {train_size} | val_size: {val_size}")

    # Options shared by the training and validation loaders, so both encode images identically.
    # NOTE(review): both loaders receive the same `cache_path` — confirm that
    # `prepare_dataloader` keeps the two caches apart.
    loader_kwargs = dict(
        clip_batch_size=clip_batch_size,
        model2=model2,
        preprocess=preprocess,
        input_size=input_size,
        rating_func=rating_func,
        shuffle=shuffle,
        flip_aug=flip_aug,
        cache_to_disk=cache_to_disk,
        cache_path=cache_path,
        max_data_loader_n_workers=max_data_loader_n_workers,
        persistent_workers=persistent_workers,
        device=device,
    )
    train_dataset, train_loader = train_utils.prepare_dataloader(
        train_dataset,
        batch_size=train_batch_size,
        **loader_kwargs,
    )
    val_dataset, val_loader = train_utils.prepare_dataloader(
        val_dataset,
        batch_size=val_batch_size,
        **loader_kwargs,
    )

    # Histogram of the ratings actually present in the prepared training set.
    rating_stat = {}
    for i in range(len(train_dataset)):
        ratings: torch.Tensor = train_dataset[i]['ratings']
        # Flatten instead of squeeze so a single-element tensor still yields a list.
        for rating in ratings.reshape(-1).tolist():
            rating_stat[rating] = rating_stat.get(rating, 0) + 1
    log_utils.info("rating_stat:\n", '\n'.join(f'{k}: {v}' for k, v in rating_stat.items()))

    # prepare model
    model: mlp.MLP = utils.load_model(resume_path, model_type=mlp_model_type, input_size=input_size, batch_norm=batch_norm, device=device, dtype=weight_dtype)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # MSE is the optimized loss; MAE is only reported during validation.
    criterion = nn.MSELoss(reduction='mean')
    criterion2 = nn.L1Loss(reduction='mean')

    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )
    log_utils.info(f"device: {accelerator.device}")

    # training loop
    best_loss = float("inf")  # best validation loss
    total_train_steps = len(train_loader) * num_train_epochs
    progress_bar = tqdm(range(total_train_steps), position=0, leave=True)
    print(f"total_train_steps: {total_train_steps}")

    loss_recorder = _LossRecorder()
    model.requires_grad_(True)
    save_on_end = False
    try:
        for epoch in range(num_train_epochs):
            model.train()
            for step, input_data in enumerate(train_loader):
                optimizer.zero_grad(set_to_none=True)
                im_emb_arr: torch.Tensor = input_data['im_emb_arrs'].to(device=accelerator.device, dtype=weight_dtype)  # presumably (batch_size, input_size)
                rating: torch.Tensor = input_data['ratings'].to(device=accelerator.device, dtype=weight_dtype)  # presumably (batch_size, 1)

                # Forward pass and loss run under autocast; backward runs outside of it.
                with accelerator.autocast():
                    output = model(im_emb_arr)
                    loss = criterion(output, rating)
                accelerator.backward(loss)
                optimizer.step()

                progress_bar.update(1)
                current_loss = loss.detach().item()
                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
                pbar_logs = {
                    'lr': f"{lr_scheduler.get_last_lr()[0]:.3e}",
                    'epoch': epoch,
                    'loss': loss_recorder.moving_average,
                }
                progress_bar.set_postfix(pbar_logs)

            progress_bar.write('epoch %d | avg loss %6.6f' % (epoch, loss_recorder.moving_average))

            # validation (main process only, never after the very first epoch)
            should_validate = (
                accelerator.is_main_process
                and epoch > 0
                and val_every_n_epochs
                and epoch % val_every_n_epochs == 0
            )
            if should_validate:
                val_result = _run_validation(model, val_loader, accelerator, weight_dtype, criterion, criterion2)
                if val_result is not None:
                    current_loss, current_mae = val_result
                    s = [f"validation - epoch {log_utils.stylize(epoch, log_utils.ANSI.YELLOW)}"]
                    s.append(f"avg MSE loss {log_utils.stylize(current_loss, log_utils.ANSI.GREEN, format_spec='.4f')}")
                    s.append(f"avg MAE loss {log_utils.stylize(current_mae, log_utils.ANSI.YELLOW, format_spec='.4f')}")
                    progress_bar.write(' | '.join(s))

                    if save_best_model and current_loss < best_loss:
                        best_loss = current_loss
                        progress_bar.write(f"best MSE val loss ({log_utils.stylize(best_loss, log_utils.ANSI.BOLD, log_utils.ANSI.GREEN)}) so far. saving model...")
                        best_save_path = Path(save_path).parent / f"{Path(save_path).stem}_best-MSE{best_loss:.4f}{Path(save_path).suffix}"
                        train_utils.save_model(model, best_save_path, epoch=epoch)
                        # Report the checkpoint that was actually written, not the final save path.
                        progress_bar.write(f"model saved: `{best_save_path}`")

            lr_scheduler.step()
            accelerator.wait_for_everyone()
    except KeyboardInterrupt:
        log_utils.warn("KeyboardInterrupt")
        if input(f"save model to {save_path}? [y/n]") == 'y':
            save_on_end = True
    else:
        save_on_end = True
    finally:
        # Close the bar even when an unexpected exception propagates.
        progress_bar.close()

    model = accelerator.unwrap_model(model)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process and save_on_end:
        log_utils.info("saving model...")
        train_utils.save_model(model, save_path)
        log_utils.info(f"model saved: `{save_path}`")
    del accelerator
    log_utils.success(f"training done. best loss: {best_loss}")
# inference test with dummy samples from the val set, sanity check
# log_utils.info("inference test with dummy samples from the val set, sanity check") | |
# model.eval() | |
# output = model(x[:5].to(device)) | |
# log_utils.info(output.size()) | |
# log_utils.info(output) | |