TastyRice committed
Commit
3ac1891
1 Parent(s): 159411c

Upload utils.py

Files changed (1):
  utils.py  +182  -0
utils.py ADDED
@@ -0,0 +1,182 @@
import gc
import os
import random
import numpy as np
import json
import torch
import uuid
from PIL import Image, PngImagePlugin
from datetime import datetime
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
)

MAX_SEED = np.iinfo(np.int32).max


@dataclass
class StyleConfig:
    prompt: str
    negative_prompt: str


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def seed_everything(seed: int) -> torch.Generator:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
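
# Example usage (sketch): the returned generator is typically handed to a
# diffusers pipeline call so results are reproducible; `pipe` is an assumed,
# already-loaded pipeline object.
#
#     generator = seed_everything(randomize_seed_fn(seed=42, randomize_seed=False))
#     image = pipe(prompt="1girl, solo", generator=generator).images[0]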


def parse_aspect_ratio(aspect_ratio: str) -> Optional[Tuple[int, int]]:
    if aspect_ratio == "Custom":
        return None
    width, height = aspect_ratio.split(" x ")
    return int(width), int(height)


def aspect_ratio_handler(
    aspect_ratio: str, custom_width: int, custom_height: int
) -> Tuple[int, int]:
    if aspect_ratio == "Custom":
        return custom_width, custom_height
    else:
        width, height = parse_aspect_ratio(aspect_ratio)
        return width, height
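
# Example usage (sketch): aspect-ratio strings such as "832 x 1216" are assumed
# to come from a UI dropdown; "Custom" falls back to the width/height inputs.
#
#     width, height = aspect_ratio_handler("832 x 1216", custom_width=1024, custom_height=1024)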


def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
    scheduler_factory_map = {
        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True
        ),
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(
            scheduler_config
        ),
        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
    }
    return scheduler_factory_map.get(name, lambda: None)()
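
# Example usage (sketch): swapping the sampler on an assumed, already-loaded
# diffusers pipeline `pipe`; unknown names return None, in which case the
# pipeline keeps its current scheduler.
#
#     scheduler = get_scheduler(pipe.scheduler.config, "Euler a")
#     if scheduler is not None:
#         pipe.scheduler = scheduler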


def free_memory() -> None:
    torch.cuda.empty_cache()
    gc.collect()


def preprocess_prompt(
    style_dict,
    style_name: str,
    positive: str,
    negative: str = "",
    add_style: bool = True,
) -> Tuple[str, str]:
    # style_dict maps a style name to a (positive template, negative prompt)
    # pair; the "(None)" entry is used as the fallback style.
    p, n = style_dict.get(style_name, style_dict["(None)"])

    if add_style and positive.strip():
        formatted_positive = p.format(prompt=positive)
    else:
        formatted_positive = positive

    combined_negative = n
    if negative.strip():
        if combined_negative:
            combined_negative += ", " + negative
        else:
            combined_negative = negative

    return formatted_positive, combined_negative
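
# Example usage (sketch): a hypothetical style dictionary in the shape
# preprocess_prompt expects, with a "{prompt}" placeholder in the template.
#
#     styles = {
#         "(None)": ("{prompt}", ""),
#         "Anime": ("anime artwork, {prompt}", "photo, deformed"),
#     }
#     pos, neg = preprocess_prompt(styles, "Anime", "1girl, solo", "lowres")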


def common_upscale(
    samples: torch.Tensor,
    width: int,
    height: int,
    upscale_method: str,
) -> torch.Tensor:
    return torch.nn.functional.interpolate(
        samples, size=(height, width), mode=upscale_method
    )


def upscale(
    samples: torch.Tensor, upscale_method: str, scale_by: float
) -> torch.Tensor:
    width = round(samples.shape[3] * scale_by)
    height = round(samples.shape[2] * scale_by)
    return common_upscale(samples, width, height, upscale_method)
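
# Example usage (sketch): hires-fix style upscale of a latent batch, assuming
# `latents` is a (batch, channels, height, width) tensor; `upscale_method`
# must be a mode accepted by torch.nn.functional.interpolate.
#
#     latents = upscale(latents, upscale_method="nearest-exact", scale_by=1.5)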


def load_character_files(character_dir: str) -> Dict[str, str]:
    character_files = {}
    for file in os.listdir(character_dir):
        if file.endswith(".txt"):
            key = f"__{file.split('.')[0]}__"  # Create a key like __character__
            character_files[key] = os.path.join(character_dir, file)
    return character_files


def get_random_line_from_file(file_path: str) -> str:
    with open(file_path, "r") as file:
        lines = file.readlines()
        if not lines:
            return ""
        return random.choice(lines).strip()


def add_character(prompt: str, character_files: Dict[str, str]) -> str:
    for key, file_path in character_files.items():
        if key in prompt:
            character_line = get_random_line_from_file(file_path)
            prompt = prompt.replace(key, character_line)
    return prompt
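
# Example usage (sketch): wildcard-style expansion. Assuming a directory
# "characters/" containing "character.txt", the token "__character__" in a
# prompt is replaced with a random line from that file.
#
#     character_files = load_character_files("characters")
#     prompt = add_character("__character__, best quality", character_files)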


def preprocess_image_dimensions(width, height):
    # Snap both dimensions down to the nearest multiple of 8.
    if width % 8 != 0:
        width = width - (width % 8)
    if height % 8 != 0:
        height = height - (height % 8)
    return width, height


def save_image(image, metadata, output_dir, is_colab):
    if is_colab:
        current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"image_{current_time}.png"
    else:
        filename = str(uuid.uuid4()) + ".png"
    os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    metadata_str = json.dumps(metadata)
    info = PngImagePlugin.PngInfo()
    info.add_text("metadata", metadata_str)
    image.save(filepath, "PNG", pnginfo=info)
    return filepath
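
# Example usage (sketch): saving a generated PIL image together with its
# generation settings; the output directory name here is only illustrative.
#
#     metadata = {"prompt": pos, "negative_prompt": neg, "seed": seed}
#     path = save_image(image, metadata, output_dir="outputs", is_colab=is_google_colab())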


def is_google_colab() -> bool:
    try:
        import google.colab  # noqa: F401

        return True
    except ImportError:
        return False