SunderAli17 commited on
Commit
4ff3ac9
·
verified ·
1 Parent(s): 5d18928

Create sampling.py

Browse files
Files changed (1) hide show
  1. flux/sampling.py +161 -0
flux/sampling.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux
9
+ from .modules.conditioner import HFEmbedder
10
+
11
+
12
+ def get_noise(
13
+ num_samples: int,
14
+ height: int,
15
+ width: int,
16
+ device: torch.device,
17
+ dtype: torch.dtype,
18
+ seed: int,
19
+ ):
20
+ return torch.randn(
21
+ num_samples,
22
+ 16,
23
+ # allow for packing
24
+ 2 * math.ceil(height / 16),
25
+ 2 * math.ceil(width / 16),
26
+ device=device,
27
+ dtype=dtype,
28
+ generator=torch.Generator(device=device).manual_seed(seed),
29
+ )
30
+
31
+
32
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
33
+ bs, c, h, w = img.shape
34
+ if bs == 1 and not isinstance(prompt, str):
35
+ bs = len(prompt)
36
+
37
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
38
+ if img.shape[0] == 1 and bs > 1:
39
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
40
+
41
+ img_ids = torch.zeros(h // 2, w // 2, 3)
42
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
43
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
44
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
45
+
46
+ if isinstance(prompt, str):
47
+ prompt = [prompt]
48
+ txt = t5(prompt)
49
+ if txt.shape[0] == 1 and bs > 1:
50
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
51
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
52
+
53
+ vec = clip(prompt)
54
+ if vec.shape[0] == 1 and bs > 1:
55
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
56
+
57
+ return {
58
+ "img": img,
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "vec": vec.to(img.device),
63
+ }
64
+
65
+
66
+ def time_shift(mu: float, sigma: float, t: Tensor):
67
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
68
+
69
+
70
+ def get_lin_function(
71
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
72
+ ) -> Callable[[float], float]:
73
+ m = (y2 - y1) / (x2 - x1)
74
+ b = y1 - m * x1
75
+ return lambda x: m * x + b
76
+
77
+
78
+ def get_schedule(
79
+ num_steps: int,
80
+ image_seq_len: int,
81
+ base_shift: float = 0.5,
82
+ max_shift: float = 1.15,
83
+ shift: bool = True,
84
+ ) -> list[float]:
85
+ # extra step for zero
86
+ timesteps = torch.linspace(1, 0, num_steps + 1)
87
+
88
+ # shifting the schedule to favor high timesteps for higher signal images
89
+ if shift:
90
+ # eastimate mu based on linear estimation between two points
91
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
92
+ timesteps = time_shift(mu, 1.0, timesteps)
93
+
94
+ return timesteps.tolist()
95
+
96
+
97
+ def denoise(
98
+ model: Flux,
99
+ # model input
100
+ img: Tensor,
101
+ img_ids: Tensor,
102
+ txt: Tensor,
103
+ txt_ids: Tensor,
104
+ vec: Tensor,
105
+ timesteps: list[float],
106
+ guidance: float = 4.0,
107
+ id_weight=1.0,
108
+ id=None,
109
+ start_step=0,
110
+ uncond_id=None,
111
+ true_cfg=1.0,
112
+ timestep_to_start_cfg=1,
113
+ neg_txt=None,
114
+ neg_txt_ids=None,
115
+ neg_vec=None,
116
+ ):
117
+ # this is ignored for schnell
118
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
119
+ use_true_cfg = abs(true_cfg - 1.0) > 1e-2
120
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
121
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
122
+ pred = model(
123
+ img=img,
124
+ img_ids=img_ids,
125
+ txt=txt,
126
+ txt_ids=txt_ids,
127
+ y=vec,
128
+ timesteps=t_vec,
129
+ guidance=guidance_vec,
130
+ id=id if i >= start_step else None,
131
+ id_weight=id_weight,
132
+ )
133
+
134
+ if use_true_cfg and i >= timestep_to_start_cfg:
135
+ neg_pred = model(
136
+ img=img,
137
+ img_ids=img_ids,
138
+ txt=neg_txt,
139
+ txt_ids=neg_txt_ids,
140
+ y=neg_vec,
141
+ timesteps=t_vec,
142
+ guidance=guidance_vec,
143
+ id=uncond_id if i >= start_step else None,
144
+ id_weight=id_weight,
145
+ )
146
+ pred = neg_pred + true_cfg * (pred - neg_pred)
147
+
148
+ img = img + (t_prev - t_curr) * pred
149
+
150
+ return img
151
+
152
+
153
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
154
+ return rearrange(
155
+ x,
156
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
157
+ h=math.ceil(height / 16),
158
+ w=math.ceil(width / 16),
159
+ ph=2,
160
+ pw=2,
161
+ )