Singularity666 committed
Commit c10dc95 · verified · Parent: e9060aa

Update main.py

Files changed (1): main.py (+94 -47)
main.py CHANGED
@@ -1,50 +1,97 @@
-# main.py
-
 import os
+import shutil
+import json
 import torch
-from torch import autocast
-from diffusers import StableDiffusionPipeline, DDIMScheduler
-from huggingface_hub import HfApi
-from app import launch_gradio_app
-from dreambooth import train_dreambooth
-
-def fine_tune_model(instance_images, class_images, instance_prompt, class_prompt, num_train_steps=800):
-    model_name = "runwayml/stable-diffusion-v1-5"
-    output_dir = "dreambooth_model"
-
-    train_dreambooth(
-        pretrained_model_name_or_path=model_name,
-        instance_data_dir=instance_images,
-        class_data_dir=class_images,
-        output_dir=output_dir,
-        instance_prompt=instance_prompt,
-        class_prompt=class_prompt,
-        num_train_steps=num_train_steps
-    )
-
-    return output_dir
-
-def load_model(model_path):
-    pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
-    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-    pipe.enable_xformers_memory_efficient_attention()
-    return pipe
-
-def generate_images(pipe, prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
-    with torch.autocast("cuda"), torch.inference_mode():
-        images = pipe(
-            prompt, height=int(height), width=int(width),
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=int(num_samples),
-            num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
-            generator=torch.Generator(device='cuda')
-        ).images
-    return images
-
-def push_to_huggingface(model_path, repo_name):
-    api = HfApi()
-    api.upload_folder(folder_path=model_path, repo_id=repo_name)
-
-if __name__ == "__main__":
-    repo_name = "Singularity666/Magix"
-    launch_gradio_app(fine_tune_model, load_model, generate_images, push_to_huggingface, repo_name)
+import random
+from pathlib import Path
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
+from accelerate import Accelerator
+from tqdm.auto import tqdm
+
+class CustomDataset(Dataset):
+    def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
+        self.data_dir = Path(data_dir)
+        self.prompt = prompt
+        self.tokenizer = tokenizer
+        self.size = size
+        self.center_crop = center_crop
+
+        self.image_transforms = transforms.Compose([
+            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5])
+        ])
+
+        # Skip caption/sidecar .txt files; everything else is treated as an image
+        self.images = [f for f in self.data_dir.iterdir() if f.is_file() and not str(f).endswith(".txt")]
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image_path = self.images[idx]
+        image = Image.open(image_path)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+
+        image = self.image_transforms(image)
+        # return_tensors="pt" so the default collate can stack prompt_ids into a (batch, seq_len) tensor
+        prompt_ids = self.tokenizer(
+            self.prompt, padding="max_length", truncation=True,
+            max_length=self.tokenizer.model_max_length, return_tensors="pt"
+        ).input_ids[0]
+
+        return {"image": image, "prompt_ids": prompt_ids}
+
+def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir, seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
+    # Setup
+    accelerator = Accelerator()
+    set_seed(seed)
+
+    # Stable Diffusion checkpoints keep each component in its own subfolder
+    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
+    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
+    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
+    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
+
+    # Only the UNet is trained; freeze the VAE and text encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
+
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
+
+    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
+    vae.to(accelerator.device)
+    text_encoder.to(accelerator.device)
+
+    unet.train()
+    global_step = 0
+    progress_bar = tqdm(total=max_train_steps)
+    # Cycle through the (typically tiny) instance dataset until max_train_steps is reached
+    while global_step < max_train_steps:
+        for batch in dataloader:
+            latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample() * 0.18215
+            noise = torch.randn_like(latents)
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]
+
+            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
+            accelerator.backward(loss)
+
+            optimizer.step()
+            optimizer.zero_grad()
+            global_step += 1
+            progress_bar.update(1)
+            if global_step >= max_train_steps:
+                break
+
+    # Save each component to its own subfolder so their config files don't overwrite one another
+    unet = accelerator.unwrap_model(unet)
+    unet.save_pretrained(os.path.join(output_dir, "unet"))
+    vae.save_pretrained(os.path.join(output_dir, "vae"))
+    text_encoder.save_pretrained(os.path.join(output_dir, "text_encoder"))
+    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
+
+def set_seed(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
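
This commit removes the old load_model helper, so nothing in main.py reloads the fine-tuned weights anymore. A minimal sketch of how the saved components could be reassembled for inference, assuming output_dir holds the subfolders written by fine_tune_model; load_finetuned_pipeline is a hypothetical name, and the base checkpoint and DDIM scheduler swap mirror the removed load_model:

# Hypothetical helper, not part of this commit: rebuild an inference pipeline
# from the components saved by fine_tune_model.
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

def load_finetuned_pipeline(output_dir, base_model="runwayml/stable-diffusion-v1-5"):
    # Override the base checkpoint's components with the fine-tuned ones
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model,
        unet=UNet2DConditionModel.from_pretrained(output_dir, subfolder="unet", torch_dtype=torch.float16),
        vae=AutoencoderKL.from_pretrained(output_dir, subfolder="vae", torch_dtype=torch.float16),
        text_encoder=CLIPTextModel.from_pretrained(output_dir, subfolder="text_encoder", torch_dtype=torch.float16),
        tokenizer=CLIPTokenizer.from_pretrained(output_dir, subfolder="tokenizer"),
        safety_checker=None,
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    return pipe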