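"""Minimal DreamBooth-style fine-tuning of a Stable Diffusion UNet with diffusers.

Fine-tunes only the UNet on a folder of instance images paired with a single
instance prompt, then saves the resulting pipeline components to output_dir.
"""
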
import math
import os
import random
from pathlib import Path

import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

class CustomDataset(Dataset):
    def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
        self.data_dir = Path(data_dir)
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop

        self.image_transforms = transforms.Compose([
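            # Resize, crop, and normalize pixel values to [-1, 1], as the VAE expects.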
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        self.images = sorted(f for f in self.data_dir.iterdir() if f.is_file() and f.suffix != ".txt")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        if image.mode != "RGB":
            image = image.convert("RGB")

        image = self.image_transforms(image)
        # return_tensors="pt" yields a (1, 77) tensor; [0] drops the batch dim so
        # the default collate_fn can stack prompt ids into a (batch, 77) tensor.
        prompt_ids = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        return {"image": image, "prompt_ids": prompt_ids}

def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir, seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
    # Setup
    accelerator = Accelerator()
    set_seed(seed)

    # Stable Diffusion checkpoints store each component in its own subfolder.
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    # Only the UNet is trained; freeze the VAE and text encoder.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
    dataloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)

    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)

    # A small instance dataset is exhausted long before max_train_steps, so
    # iterate over it for as many epochs as the step budget requires.
    unet.train()
    num_epochs = math.ceil(max_train_steps / len(dataloader))
    progress_bar = tqdm(total=max_train_steps, desc="Steps")

    global_step = 0
    for _ in range(num_epochs):
        for batch in dataloader:
            # Encode images to latents and embed the prompt; both models are frozen.
            with torch.no_grad():
                latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample()
                latents = latents * 0.18215  # SD v1.x VAE scaling factor
                encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]

            # Sample noise and a random timestep per image, then noise the latents.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict the added noise and regress against it (epsilon objective).
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            progress_bar.update(1)
            if global_step >= max_train_steps:
                break
        if global_step >= max_train_steps:
            break
    progress_bar.close()

    # Save each component to its own subfolder (the diffusers pipeline layout);
    # writing them all to one directory would overwrite the shared config.json.
    unet = accelerator.unwrap_model(unet)
    unet.save_pretrained(os.path.join(output_dir, "unet"))
    vae.save_pretrained(os.path.join(output_dir, "vae"))
    text_encoder.save_pretrained(os.path.join(output_dir, "text_encoder"))
    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
    noise_scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))

def set_seed(seed):
    # Seed Python's and PyTorch's RNGs (CPU and all GPUs) for reproducibility.
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
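
# Example invocation: a minimal sketch. The model id, prompt, and paths below are
# illustrative placeholders (not part of the original script); swap in your own
# instance-image folder and a Stable Diffusion v1.x checkpoint.
if __name__ == "__main__":
    fine_tune_model(
        instance_data_dir="./instance_images",
        instance_prompt="a photo of sks dog",
        model_name="runwayml/stable-diffusion-v1-5",
        output_dir="./dreambooth-output",
    )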