# Source: sinder/main.py — Hugging Face upload by haoqiwang (commit 9ae1b1e, "add files")
#!/usr/bin/env python
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from sinder import (
get_neighbor_loss,
get_tokens,
load_data,
load_model,
load_visual_data,
pca_array,
replace_back,
replace_linear_addition_noqk,
)
# Disable xFormers attention kernels before the model is loaded —
# presumably so SINDER can patch the plain PyTorch attention path;
# confirm against sinder.load_model.
os.environ['XFORMERS_DISABLED'] = '1'
# Allow TF32 matmul on Ampere+ GPUs (faster float32 at slightly reduced precision).
torch.set_float32_matmul_precision('high')
def parse_args(argv=None):
    """Parse command-line options for the beautify run.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` as before;
            passing a list makes the function testable without touching
            global state.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Beautify network')
    parser.add_argument(
        '--model', type=str, default='dinov2_vitg14', help='config file'
    )
    parser.add_argument('--work_dir', type=str, default='results')
    parser.add_argument('--resolution', type=int, default=518)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--max_iter', type=int, default=30000)
    parser.add_argument('--num_train_max', type=int, default=30000)
    # Threshold on the angle-based artifact mask (see get_neighbor_loss).
    parser.add_argument('--mask_thr', type=float, default=4)
    parser.add_argument('--skip_less_than', type=int, default=3)
    parser.add_argument('--visual_size', type=int, default=448 * 2)
    parser.add_argument('--kernel', type=int, default=3)
    # Skip-density percentages at which intermediate checkpoints are saved.
    parser.add_argument('--save_at_skip', type=int, nargs='+', default=[75])
    # Only the last `limit_layers` transformer blocks keep their gradients.
    parser.add_argument('--limit_layers', type=int, default=10)
    return parser.parse_args(argv)
def prepare_train(args, model):
    """Freeze the model, install trainable epsilon layers, build the optimizer.

    Every original parameter is frozen, then
    ``replace_linear_addition_noqk`` patches the model in place; the
    ``.epsilon`` parameters it introduces are the only trainable weights.

    Args:
        args: parsed options (uses ``args.lr``).
        model: model to patch in place.

    Returns:
        torch.optim.SGD over the epsilon parameters (momentum 0.9).

    Raises:
        AssertionError: if any non-epsilon parameter is left trainable.
    """
    model.train()
    # Freeze everything first so only parameters added by the surgery
    # below can require grad.
    for param in model.parameters():
        param.requires_grad = False
    replace_linear_addition_noqk(model, 'model')

    all_params = []   # trainable epsilon tensors, handed to the optimizer
    grad_params = []  # names of *all* trainable params, for the sanity check
    # Single pass instead of the original two redundant passes.
    for name, param in model.named_parameters():
        if param.requires_grad:
            grad_params.append(name)
            if '.epsilon' in name:
                all_params.append(param)
    # Nothing except the epsilon parameters may be trainable.
    assert len(grad_params) == len(all_params)
    print(len(grad_params), grad_params)
    print(len(all_params), all_params)
    return torch.optim.SGD(
        all_params,
        lr=args.lr,
        momentum=0.9,
    )
def save_model(args, model):
    """Restore the vanilla layer types and write the weights to disk."""
    print('save model')
    model.eval()
    # Undo the epsilon-layer surgery so the saved state_dict matches the
    # original architecture.
    replace_back(model, 'model')
    checkpoint_path = args.folder / 'model.pt'
    torch.save(model.state_dict(), checkpoint_path)
def train(args, model, dataset, optimizer, visual_dataset):
    """Run the SINDER artifact-repair training loop.

    Each iteration feeds one image through the model, computes the
    neighbor loss on detected artifact tokens, and steps the optimizer
    (only the unfrozen epsilon parameters have grads). Iterations where
    no artifact is found are recorded as "skips"; once the skip density
    over the last 1000 iterations crosses each ``args.save_at_skip``
    percentage a checkpoint is written, and training stops early when
    all such checkpoints exist. Otherwise runs for ``args.max_iter``
    iterations. Periodically writes PCA / mask / image visualizations
    into ``args.folder``.

    Args:
        args: parsed options; ``args.folder`` must already exist.
        model: patched model (see ``prepare_train``); moved to CUDA by caller.
        dataset: indexable dataset yielding (C, H, W) image tensors.
        optimizer: optimizer over the trainable parameters.
        visual_dataset: images used only for periodic PCA visualization.
    """

    def _denorm(image):
        # Undo the (approximate) normalization applied by load_data —
        # assumes mean 0.45 / std 0.22; TODO confirm against sinder.load_data.
        arr = image.permute((1, 2, 0)).cpu().numpy() * 0.22 + 0.45
        return Image.fromarray((arr * 255).clip(0, 255).astype(np.uint8))

    print('training')
    # Seed with 1000 False entries so the density starts at 0 and the
    # sliding window skip_history[-1000:] is always full.
    skip_history = [False] * 1000
    model.train()
    for global_iter in tqdm(range(args.max_iter)):
        img = dataset[global_iter % len(dataset)]
        H = img.shape[1] // model.patch_size
        W = img.shape[2] // model.patch_size
        # Fraction of the last 1000 iterations that found no artifact
        # loss; high density means most images no longer need repair.
        density = np.array(skip_history[-1000:]).astype(float).mean()
        print(f'{global_iter=} {W=} {H=} {density=:.2f}')
        # FIX: iterate over a snapshot — the original removed elements
        # from the list while iterating it, which skips thresholds when
        # several are crossed in the same iteration.
        for percent in list(args.save_at_skip):
            if percent / 100 <= density:
                print(f'save checkpoint at {density=}')
                args.save_at_skip.remove(percent)
                torch.save(model, args.folder / f'checkpoint_p{percent}.pth')
        # All requested skip-density checkpoints written: stop early.
        if len(args.save_at_skip) == 0:
            break
        model.zero_grad()
        model.train()
        with torch.enable_grad():
            image_batch = img.unsqueeze(0).cuda()
            result = get_neighbor_loss(
                model,
                image_batch,
                skip_less_than=args.skip_less_than,
                mask_thr=args.mask_thr,
                kernel=args.kernel,
            )
            if result is None:
                # No artifact tokens found for this image.
                skip_history.append(True)
                print('no loss, skip')
            else:
                skip_history.append(False)
                (
                    layer,
                    loss,
                    I,
                    J,
                    T,
                    alpha,
                    mask_angle,
                    x_token,
                ) = result
                print(
                    f'{global_iter=}, {layer=}, {density=}, {alpha=:.2f}, {len(I)=}, '
                    f'{loss.item()=:.2f}'
                )
                if torch.isnan(loss).any():
                    print('nan loss, skip')
                    continue
                loss.backward()
                # set some grad to 0
                # Drop gradients for all blocks before the last
                # `limit_layers` blocks up to the layer that produced
                # the loss.
                if args.limit_layers:
                    with torch.no_grad():
                        for t in range(layer - args.limit_layers + 1):
                            for p in model.blocks[t].parameters():
                                p.grad = None
                # Skip the update entirely if any surviving grad is NaN.
                has_nan = False
                for name, param in model.named_parameters():
                    if param.grad is not None and torch.isnan(param.grad).any():
                        print(f'nan grad at {name}, skip')
                        has_nan = True
                if has_nan:
                    continue
                optimizer.step()
        # visualize
        if global_iter % 100 == 0:
            # Best-effort: x_token / mask_angle come from the last
            # non-skipped iteration and may be undefined early on, so
            # failures are printed, never raised.
            try:
                print(f'visualization at {global_iter=}')
                pca_img = pca_array(x_token)
                pca_img.save(args.folder / 'pca.png')
                mask_img = Image.fromarray(
                    (mask_angle * 255)
                    .detach()
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                ).resize((W * 7, H * 7), resample=Image.NEAREST)
                mask_img.save(args.folder / 'mask.png')
                _denorm(img).save(args.folder / 'img.png')
                if global_iter % 1000 == 0:
                    # Keep numbered copies every 1000 iterations.
                    pca_img.save(args.folder / f'{global_iter:05}_pca.png')
                    mask_img.save(
                        args.folder / f'{global_iter:05}_mask.png'
                    )
                    _denorm(img).save(
                        args.folder / f'{global_iter:05}_img.png'
                    )
            except Exception as e:
                print(e)
            # PCA of the last-layer tokens for each visualization image.
            for d in range(len(visual_dataset)):
                visual_image = visual_dataset[d]
                visual_tokens_all = get_tokens(model, visual_image)
                visual_tokens, visual_tokens_cls = zip(*visual_tokens_all)
                pca_img = pca_array(visual_tokens[-1])
                pca_img.save(args.folder / f'{d}_pca.png')
                if global_iter % 500 == 0:
                    pca_img.save(
                        args.folder / f'{global_iter:05}_{d}_pca.png'
                    )
    # Final full-model checkpoint after the loop ends (or breaks early).
    torch.save(model, args.folder / 'checkpoint.pth')
def main():
    """Entry point: build the work dir, load model and data, train, save."""
    print('Start beautify')
    args = parse_args()
    # Encode the key hyperparameters in the run-directory name.
    run_name = (
        f'res{args.resolution}_lr{args.lr}_{args.num_train_max}'
        f'_skipless{args.skip_less_than}_maskthr{args.mask_thr}'
        f'_limit{args.limit_layers}_ker{args.kernel}'
    )
    args.folder = Path(args.work_dir) / run_name
    args.folder.mkdir(parents=True, exist_ok=True)
    print(args)
    print(' '.join(sys.argv))
    print(f'work dir {args.folder}')
    model = load_model(args.model)
    dataset = load_data(args, model)
    visual_dataset = load_visual_data(args, model)
    optimizer = prepare_train(args, model)
    train(args, model, dataset, optimizer, visual_dataset)
    save_model(args, model)
if __name__ == '__main__':
main()