|
import csv |
|
import dataclasses |
|
import subprocess |
|
from copy import deepcopy |
|
import itertools |
|
from concurrent.futures import ThreadPoolExecutor |
|
import pathlib |
|
from typing import List |
|
import diffusers |
|
import transformers |
|
import safetensors.torch |
|
import torch
import torch.utils.data
|
from tqdm import tqdm |
|
from datetime import datetime |
|
import random |
|
import os |
|
import time |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
torch.manual_seed(0) |
|
random.seed(0) |
|
|
|
|
|
LATENTS_OUTPUT_DIR = pathlib.Path("latents") |
|
CAPTIONS_OUTPUT_DIR = pathlib.Path("captions2") |
|
DANBOORU_ARTISTS_PATH = pathlib.Path("danbooru_artist.csv") |
|
E621_ARTISTS_PATH = pathlib.Path("e621_artist.csv") |
|
LOCK_FILE = "safetensors.lock" |
|
|
|
|
|
device = torch.device("cuda") |
|
dtype = torch.float16 |
|
|
|
|
|
train_logger = SummaryWriter(f"logs/pony_scoreless_{datetime.now().strftime('%Y%m%d_%H%M%S')}") |
|
|
|
|
|
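# Accumulate per-parameter gradients of the prediction mismatch between two SDXL
# checkpoints (NoobAI-XL v1.1 vs. Animagine XL 4.0). Both models see the same noisy
# latents but model-specific quality/artist prompts; gradients are kept in fp16 with a
# per-parameter power-of-two scale (see create_hook below) and periodically written to
# safetensors files for a separate download script to pick up.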
def accumulate_grads(): |
|
batch_size = 1 |
|
epochs = 1 |
|
|
|
tokenizer = create_tokenizer(device) |
|
|
|
model_a = diffusers.StableDiffusionXLPipeline.from_single_file( |
|
"NoobAI-XL-v1.1.safetensors", |
|
torch_dtype=dtype, |
|
) |
|
delattr(model_a, "vae") |
|
model_a.unet.to(device=device) |
|
|
|
model_a.unet.enable_gradient_checkpointing() |
|
model_a.text_encoder.to(device=device) |
|
model_a.text_encoder.gradient_checkpointing_enable() |
|
model_a.text_encoder_2.to(device=device) |
|
model_a.text_encoder_2.gradient_checkpointing_enable() |
|
model_a.text_encoder_combined = CombinedCLIPTextEncoder(model_a.text_encoder, model_a.text_encoder_2, batch_size) |
|
|
|
model_b = diffusers.StableDiffusionXLPipeline.from_single_file( |
|
"animagine-xl-4.0.safetensors", |
|
torch_dtype=dtype, |
|
) |
|
delattr(model_b, "vae") |
|
model_b.unet.to(device=device) |
|
|
|
model_b.unet.enable_gradient_checkpointing() |
|
model_b.text_encoder.to(device=device) |
|
model_b.text_encoder.gradient_checkpointing_enable() |
|
model_b.text_encoder_2.to(device=device) |
|
model_b.text_encoder_2.gradient_checkpointing_enable() |
|
model_b.text_encoder_combined = CombinedCLIPTextEncoder(model_b.text_encoder, model_b.text_encoder_2, batch_size) |
|
|
|
model_a.unet.eval() |
|
model_a.text_encoder.eval() |
|
model_a.text_encoder_2.eval() |
|
model_b.unet.eval() |
|
model_b.text_encoder.eval() |
|
model_b.text_encoder_2.eval() |
|
|
|
|
scheduler = create_scheduler(device) |
|
data_loader = get_data_loader(tokenizer, batch_size) |
|
total_steps = 0 |
|
|
|
log_scalars_a = {} |
|
log_scalars_b = {} |
|
log_scalars_sync = {} |
|
|
|
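    # Per-parameter dynamic loss scaling for fp16 accumulation: each incoming gradient is
    # accumulated as |grad| * 2**log_scalars[k], starting at 2**20. Whenever the running
    # sum would become non-finite, the exponent is decremented and the accumulated grad is
    # halved via ldexp_(n1). This assumes non-finite values only ever come from overflow.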
n1 = torch.tensor(-1, device=device, dtype=torch.long) |
|
ldexp_offset = torch.tensor(20, device=device, dtype=torch.long) |
|
def create_hook(param, k, log_scalars): |
|
param.grad = torch.zeros_like(param) |
|
log_scalars[k] = ldexp_offset.clone() |
|
|
|
def hook(grad): |
|
nonlocal param, log_scalars, k |
|
while True: |
|
new_grad = param.grad + grad.abs().ldexp(log_scalars[k]) |
|
if not new_grad.isfinite().all(): |
|
log_scalars[k] -= 1 |
|
param.grad.ldexp_(n1) |
|
else: |
|
break |
|
|
|
param.grad.copy_(new_grad) |
|
return param.grad |
|
|
|
return hook |
|
|
|
for model, log_scalars in ((model_a, log_scalars_a), (model_b, log_scalars_b)): |
|
for k, v in get_params(model): |
|
v.register_hook(create_hook(v, k, log_scalars)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
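    # Shared closure state for get_pred: both models are evaluated on the same noisy
    # latents and timesteps, each in its own thread of the ThreadPoolExecutor below.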
noisy_latents = timesteps = time_ids = None |
|
def get_pred(args): |
|
nonlocal noisy_latents, timesteps, time_ids |
|
model, tokens = args |
|
txt = model.text_encoder_combined(tokens[0]) |
|
return model.unet( |
|
noisy_latents, |
|
timesteps, |
|
encoder_hidden_states=txt["conds"], |
|
added_cond_kwargs={ |
|
"text_embeds": txt["pooled"], |
|
"time_ids": time_ids, |
|
}, |
|
).sample |
|
|
|
params = list(v for k, v in itertools.chain(get_params(model_a), get_params(model_b))) |
|
with ThreadPoolExecutor(max_workers=2) as worker: |
|
for epoch_i in range(epochs): |
|
for step_i, (latent_infos, tokens_a, tokens_b, post_ids) in enumerate(tqdm(data_loader)): |
|
latents = torch.cat([latent_info["latent"] for latent_info in latent_infos], dim=0).to(device=device, dtype=dtype) |
|
crop_hw = torch.stack([latent_info["crop_hw"] for latent_info in latent_infos]).to(device=device) |
|
orig_hw = torch.stack([latent_info["orig_hw"] for latent_info in latent_infos]).to(device=device) |
|
|
|
noise, noisy_latents, timesteps = get_noise_noisy_latents_and_timesteps(scheduler, latents) |
|
time_ids = get_add_time_ids(orig_hw, crop_hw) |
|
|
|
|
|
|
|
|
|
|
|
pred_a, pred_b = worker.map(get_pred, ((model_a, tokens_a), (model_b, tokens_b))) |
|
|
|
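                # Per-sample MSE between the two models' predictions; dividing by the
                # detached value normalises the loss to 1, so each step contributes the
                # gradient of the relative prediction difference, not its absolute size.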
mse = torch.nn.functional.mse_loss(pred_a, pred_b, reduction="none").flatten(start_dim=1).mean(dim=-1) |
|
loss = (mse / mse.detach()).mean() |
|
|
|
train_logger.add_scalar("grads/loss", loss.item(), total_steps) |
|
train_logger.add_scalar("grads/loss_raw", mse.mean().item(), total_steps) |
|
train_logger.add_scalar("grads/timestep", timesteps[0].item(), total_steps) |
|
|
|
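                # The returned gradients are discarded on purpose: the hooks registered
                # in create_hook accumulate the scaled |grad| values into param.grad.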
torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True) |
|
|
|
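                # Rescale both models' accumulated gradients onto a common (minimum)
                # exponent so the saved tensors are comparable, and record that exponent.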
for (k, v_a), (k_b, v_b) in zip(get_params(model_a), get_params(model_b)): |
|
assert k == k_b |
|
if v_a.grad is not None and v_b.grad is not None: |
|
while log_scalars_a[k] > log_scalars_b[k]: |
|
log_scalars_a[k] -= 1 |
|
v_a.grad.ldexp_(n1) |
|
while log_scalars_b[k] > log_scalars_a[k]: |
|
log_scalars_b[k] -= 1 |
|
v_b.grad.ldexp_(n1) |
|
log_scalars_sync[k] = log_scalars_a[k] |
|
|
|
if (step_i + 1) % 10 == 0: |
|
train_logger.add_scalar("grads/max_a", max(v.grad.max().item() for k, v in get_params(model_a) if v.grad is not None), total_steps) |
|
train_logger.add_scalar("grads/max_b", max(v.grad.max().item() for k, v in get_params(model_b) if v.grad is not None), total_steps) |
|
|
|
if (step_i + 1) % 1000 == 0: |
|
save_grads(model_a, "grads_a.safetensors", first=True) |
|
safetensors.torch.save_file(log_scalars_sync, "log_scalars.safetensors") |
|
save_grads(model_b, "grads_b.safetensors", last=True) |
|
|
|
total_steps += batch_size |
|
|
|
|
|
def get_modules(model): |
|
return itertools.chain( |
|
prefix_iter(model.unet.named_modules(), "unet."), |
|
prefix_iter(model.text_encoder.named_modules(), "text_encoder."), |
|
prefix_iter(model.text_encoder_2.named_modules(), "text_encoder_2."), |
|
) |
|
|
|
|
|
def get_params(model): |
|
return itertools.chain( |
|
prefix_iter(model.unet.named_parameters(), "unet."), |
|
prefix_iter(model.text_encoder.named_parameters(), "text_encoder."), |
|
prefix_iter(model.text_encoder_2.named_parameters(), "text_encoder_2."), |
|
) |
|
|
|
|
|
def prefix_iter(item_iter, prefix): |
|
return ((prefix + k, v) for k, v in item_iter) |
|
|
|
|
|
def save_grads(model, path, first=False, last=False): |
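    # Write the accumulated gradients to `path`. Before the first file of a pair, wait
    # until the previous lock file has been removed (the last pair was downloaded);
    # after the last file, recreate the lock file to signal the download script.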
|
if first: |
|
wait_for_lock_removal() |
|
|
|
safetensors.torch.save_file( |
|
{k: v.grad.cpu().contiguous() for k, v in get_params(model) if v.grad is not None}, |
|
path, |
|
) |
|
|
|
if last: |
|
|
|
with open(LOCK_FILE, "w") as f: |
|
f.write("pending download") |
|
print("Checkpoint pair saved, lock file created.") |
|
|
|
|
|
def wait_for_lock_removal(poll_interval=5): |
|
"""Wait until the lock file is removed by the local download script.""" |
|
while os.path.exists(LOCK_FILE): |
|
time.sleep(poll_interval) |
|
|
|
|
|
def create_scheduler(device: torch.device): |
|
scheduler = diffusers.DDPMScheduler( |
|
beta_start=0.00085, |
|
beta_end=0.012, |
|
beta_schedule="scaled_linear", |
|
num_train_timesteps=1000, |
|
clip_sample=False, |
|
) |
|
|
|
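    # Inverse SNR per timestep, (1 - alpha_bar_t) / alpha_bar_t. Normalised to sum to 1
    # it doubles as the timestep sampling distribution, so high-noise timesteps are
    # drawn more often.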
inv_snr = ((1-scheduler.alphas_cumprod) / scheduler.alphas_cumprod).to(device) |
|
scheduler.inv_snr = inv_snr |
|
scheduler.inv_snr_weights = inv_snr / inv_snr.sum() |
|
return scheduler |
|
|
|
|
|
def debiased_loss_scaling(timesteps, noise_scheduler): |
|
return noise_scheduler.inv_snr[timesteps] |
|
|
|
|
|
def get_noise_noisy_latents_and_timesteps(scheduler, latents): |
|
batch_size = latents.shape[0] |
|
noise = torch.randn_like(latents, device=latents.device) |
|
|
|
timesteps = torch.multinomial(scheduler.inv_snr_weights, batch_size) |
|
noisy_latents = scheduler.add_noise(latents, noise, timesteps) |
|
return noise, noisy_latents, timesteps |
|
|
|
|
|
def get_add_time_ids(original_size, crops_coords_top_left): |
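    # SDXL micro-conditioning vector: [orig_h, orig_w, crop_top, crop_left, target_h,
    # target_w], with the target size hard-coded to 1024x1024.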
|
add_time_ids = torch.cat([ |
|
original_size, |
|
crops_coords_top_left, |
|
torch.tensor([[1024]*2], device=original_size.device).expand(len(original_size), -1), |
|
], dim=1) |
|
|
|
return add_time_ids |
|
|
|
|
|
def get_data_loader(tokenizer, batch_size: int): |
|
return torch.utils.data.DataLoader( |
|
PromptDataset(tokenizer), |
|
batch_size=batch_size, |
|
shuffle=True, |
|
collate_fn=lambda x: zip(*x), |
|
) |
|
|
|
|
|
@dataclasses.dataclass |
|
class ArtistScore: |
|
artist_tag: str |
|
count: int |
|
|
|
|
|
class PromptDataset(torch.utils.data.Dataset): |
|
def __init__(self, tokenizer): |
|
self.tokenizer = tokenizer |
|
self.latent_paths = list(LATENTS_OUTPUT_DIR.iterdir()) |
|
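        # Artist tag pools: the "b" pool is Danbooru artists only, the "a" pool
        # additionally includes e621 artists; sampling weights are proportional to
        # post count.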
with open(DANBOORU_ARTISTS_PATH, "r", encoding='utf-8') as f: |
|
reader = csv.DictReader(f) |
|
self.b_artists = [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] != "banned_artist"] |
|
self.b_artists.sort(key=lambda t: t.count, reverse=True) |
|
self.b_artist_scores = torch.tensor(list(map(lambda t: t.count, self.b_artists)), device=device, dtype=torch.float32) |
|
self.b_artist_scores /= self.b_artist_scores.sum() |
|
|
|
with open(E621_ARTISTS_PATH, "r", encoding='utf-8') as f: |
|
            reader = csv.DictReader(f)
|
self.a_artists = self.b_artists + [ArtistScore(r["trigger"], int(r["count"])) for r in reader if r["artist"] not in ["conditional_dnp", "avoid_posting", "unknown_artist", "third-party_edit", "sound_warning", "anonymous_artist"]] |
|
self.a_artists.sort(key=lambda t: t.count, reverse=True) |
|
self.a_artist_scores = torch.tensor(list(map(lambda t: t.count, self.a_artists)), device=device, dtype=torch.float32) |
|
self.a_artist_scores /= self.a_artist_scores.sum() |
|
|
|
self.a_prefix = "masterpiece, best quality, newest, absurdres, highres, safe, " |
|
self.b_suffix = ", masterpiece, high score, great score, absurdres" |
|
|
|
def __len__(self): |
|
return len(self.latent_paths) |
|
|
|
def __getitem__(self, item): |
|
post_id = self.latent_paths[item].stem |
|
latent = safetensors.torch.load_file(LATENTS_OUTPUT_DIR / f"{post_id}.safetensors", device=str(device)) |
|
caption = (CAPTIONS_OUTPUT_DIR / f"{post_id}.txt").read_text() |
|
|
|
caption_a = self.a_prefix + caption |
|
caption_b = caption + self.b_suffix |
|
|
|
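        # Alternate which model receives an artist tag: even items prepend a sampled
        # artist to caption_a (pool including e621), odd items to caption_b (Danbooru
        # pool), weighted by post count.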
if item % 2 == 0: |
|
artist_a = self.a_artists[torch.multinomial(self.a_artist_scores, 1).item()] |
|
caption_a = artist_a.artist_tag + ", " + caption_a |
|
else: |
|
artist_b = self.b_artists[torch.multinomial(self.b_artist_scores, 1).item()] |
|
caption_b = artist_b.artist_tag + ", " + caption_b |
|
|
|
tokens_a = self.tokenizer.chunk_tokens(self.tokenizer([caption_a.replace("),", ") ,")])) |
|
tokens_b = self.tokenizer.chunk_tokens(self.tokenizer([caption_b.replace("),", ") ,")])) |
|
return latent, tokens_a, tokens_b, post_id |
|
|
|
|
|
class CombinedCLIPTextEncoder(torch.nn.Module): |
|
def __init__(self, clip_l, clip_g, batch_size): |
|
super().__init__() |
|
assert batch_size == 1 |
|
self.clip_l = clip_l |
|
self.clip_g = clip_g |
|
|
|
def forward(self, tokens): |
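        # SDXL text conditioning: concatenate the penultimate hidden states of CLIP-L
        # (768-dim) and CLIP-G (1280-dim) into 2048-dim token embeddings; the pooled
        # embedding is CLIP-G's projected text_embeds taken at the first chunk of each
        # prompt.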
|
tokens_clip_l = tokens["clip_l"].copy() |
|
del tokens_clip_l["prompt_starts"] |
|
|
|
tokens_clip_g = tokens["clip_g"].copy() |
|
clip_g_starts = tokens_clip_g.pop("prompt_starts") |
|
|
|
clip_l_encoded = self.clip_l(**tokens_clip_l, output_hidden_states=True, return_dict=True) |
|
clip_g_encoded = self.clip_g(**tokens_clip_g, output_hidden_states=True, return_dict=True) |
|
combined_encoded = torch.cat([clip_l_encoded["hidden_states"][-2], clip_g_encoded["hidden_states"][-2]], dim=-1) |
|
combined_encoded_reshape = combined_encoded.reshape(1, -1, 2048) |
|
|
|
return { |
|
"conds": combined_encoded_reshape, |
|
"pooled": clip_g_encoded.text_embeds[clip_g_starts], |
|
} |
|
|
|
|
|
def create_tokenizer(device: torch.device): |
|
tokenizer_l = transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") |
|
tokenizer_g = transformers.CLIPTokenizer.from_pretrained("laion/CLIP-ViT-g-14-laion2B-s34B-b88K") |
|
return CombinedCLIPTokenizer(tokenizer_l, tokenizer_g, device) |
|
|
|
|
|
class CombinedCLIPTokenizer(torch.nn.Module): |
|
comma_token = 267 |
|
|
|
def __init__(self, tokenizer_l, tokenizer_g, output_device: torch.device): |
|
super().__init__() |
|
self.tokenizer_l = tokenizer_l |
|
self.tokenizer_g = tokenizer_g |
|
self.output_device = output_device |
|
|
|
def forward(self, prompts: List[str]) -> dict: |
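        # Tokenize once with the CLIP-L tokenizer and reuse the ids for CLIP-G; this
        # relies on both SDXL tokenizers sharing the same BPE vocabulary. BOS/EOS and
        # padding are added per chunk later in chunk_tokens.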
|
tokens_l = self.tokenizer_l(prompts, add_special_tokens=False) |
|
return { |
|
"clip_l": tokens_l, |
|
"clip_g": deepcopy(tokens_l), |
|
} |
|
|
|
def chunk_tokens(self, tokens: dict): |
|
return { |
|
"clip_l": self._chunk_tokens_impl(self.tokenizer_l, tokens["clip_l"]), |
|
"clip_g": self._chunk_tokens_impl(self.tokenizer_g, tokens["clip_g"]), |
|
} |
|
|
|
def _chunk_tokens_impl(self, tokenizer, tokens: dict): |
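        # Split each prompt into 77-token chunks (75 content tokens plus BOS/EOS),
        # breaking at comma tokens so tags are not split across chunk boundaries; pad
        # each chunk, build its attention mask, and record how many chunks each prompt
        # produced (used to compute prompt_starts).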
|
input_ids = [] |
|
attention_masks = [] |
|
chunk_counts = [] |
|
|
|
for prompt, mask in zip(tokens["input_ids"], tokens["attention_mask"]): |
|
last_comma = 0 |
|
current_chunk = [] |
|
chunks = [] |
|
chunks_attn = [] |
|
|
|
def next_chunk(): |
|
nonlocal current_chunk |
|
current_chunk = [tokenizer.bos_token_id] + current_chunk + [tokenizer.eos_token_id] |
|
num_tokens = len(current_chunk) |
|
|
|
current_chunk.extend([tokenizer.pad_token_id] * (77 - num_tokens)) |
|
chunks.append(current_chunk) |
|
current_chunk = [] |
|
chunks_attn.append([1] * num_tokens + [0] * (77 - num_tokens)) |
|
|
|
for token_i, token in enumerate(prompt): |
|
is_last_token = token_i == len(prompt) - 1 |
|
seq_suffix = prompt[last_comma:token_i + int(is_last_token)] |
|
|
|
if token == self.comma_token or is_last_token: |
|
if len(current_chunk) + len(seq_suffix) > 77 - 2: |
|
next_chunk() |
|
seq_suffix = prompt[last_comma+1:token_i + int(is_last_token)] |
|
|
|
|
|
current_chunk.extend(seq_suffix) |
|
last_comma = token_i |
|
|
|
if current_chunk or not chunks: |
|
next_chunk() |
|
|
|
chunk_counts.append(len(chunks)) |
|
input_ids.extend(chunks) |
|
attention_masks.extend(chunks_attn) |
|
|
|
return { |
|
"input_ids": torch.tensor(input_ids, device=self.output_device), |
|
"attention_mask": torch.tensor(attention_masks, device=self.output_device), |
|
"prompt_starts": torch.tensor([0] + chunk_counts[:-1], device=self.output_device).cumsum(dim=0), |
|
} |
|
|
|
|
|
def shutdown_machine(): |
|
"""Shutdown the machine. Adjust the command as necessary for your environment.""" |
|
|
|
wait_for_lock_removal() |
|
print("All checkpoints have been downloaded. Shutting down the machine.") |
|
try: |
|
subprocess.run("runpodctl stop pod $RUNPOD_POD_ID", shell=True, check=True) |
|
except Exception as e: |
|
print(f"Error shutting down: {e}") |
|
|
|
|
|
if __name__ == "__main__": |
|
accumulate_grads() |
|
shutdown_machine() |
|
|