1aurent committed
Commit
256da70
1 Parent(s): 0394c27
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
-title: Ic Light
+title: Refiners IC-Light
 emoji: 👁
 colorFrom: gray
 colorTo: blue
 sdk: gradio
-sdk_version: 4.40.0
-app_file: app.py
+sdk_version: 5.1.0
+app_file: src/app.py
 pinned: false
 license: mit
 ---
examples/bunny.png ADDED
examples/chair.png ADDED
examples/plant.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,3 @@
+
+[tool.ruff]
+line-length = 120
requirements.txt ADDED
@@ -0,0 +1,2 @@
+git+https://github.com/finegrain-ai/refiners@06204731093d8055e65b21b4da2ce586737d6ea4
+pillow-heif>=0.18.0
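
Note that refiners is pinned to an exact commit rather than a tagged release. A minimal sanity check, assuming the requirements have been installed, is that the pinned revision exposes the IC-Light adapter imported by src/utils.py below:

```python
# Sanity check (illustrative, not part of the commit): after
# `pip install -r requirements.txt`, the pinned refiners revision
# must provide the ICLight adapter that src/utils.py imports.
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight

print(ICLight.__name__)  # prints "ICLight" if the pinned revision is installed
```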
src/app.py ADDED
@@ -0,0 +1,242 @@
+import gradio as gr  # pyright: ignore[reportMissingTypeStubs]
+import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
+import spaces  # pyright: ignore[reportMissingTypeStubs]
+import torch
+from PIL import Image
+from refiners.fluxion.utils import manual_seed, no_grad
+
+from utils import LightingPreference, load_ic_light, resize_modulo_8
+
+pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
+pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]
+
+TITLE = """
+# IC-Light with Refiners
+"""
+
+# initialize the IC-Light model on the CPU
+DEVICE_CPU = torch.device("cpu")
+DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)
+
+# "move" the model to the GPU; this is handled/intercepted by ZeroGPU
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ic_light.to(device=DEVICE, dtype=DTYPE)
+ic_light.device = DEVICE
+ic_light.dtype = DTYPE
+ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)
+
+
+@spaces.GPU
+@no_grad()
+def process(
+    image: Image.Image,
+    light_pref: str,
+    prompt: str,
+    negative_prompt: str,
+    strength_first_pass: float,
+    strength_second_pass: float,
+    condition_scale: float,
+    num_inference_steps: int,
+    seed: int,
+) -> Image.Image:
+    assert image.mode == "RGBA"
+    assert 0 < strength_first_pass <= 1  # strictly positive: used as a divisor below
+    assert 0 < strength_second_pass <= 1
+    assert num_inference_steps > 0
+    assert seed >= 0
+
+    # set the seed
+    manual_seed(seed)
+
+    # resize image to ~768x768
+    image = resize_modulo_8(image, 768)
+
+    # split RGB and alpha channel
+    mask = image.getchannel("A")
+    image = image.convert("RGB")
+
+    # compute embeddings
+    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    ic_light.set_ic_light_condition(image=image, mask=mask)
+
+    # get the light_pref_image
+    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
+        width=image.width,
+        height=image.height,
+        interval=(0.2, 0.8),
+    )
+
+    # if no light preference is provided, do a full-strength first pass
+    if light_pref_image is None:
+        x = torch.randn_like(ic_light._ic_light_condition)  # pyright: ignore[reportPrivateUsage]
+        strength_first_pass = 1.0
+    else:
+        x = ic_light.lda.image_to_latents(light_pref_image)
+        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)
+
+    # configure the first pass
+    num_steps = int(round(num_inference_steps / strength_first_pass))
+    first_step = int(num_steps * (1 - strength_first_pass))
+    ic_light.set_inference_steps(num_steps, first_step)
+
+    # first pass
+    for step in ic_light.steps:
+        x = ic_light(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=condition_scale,
+        )
+
+    # configure the second pass
+    num_steps = int(round(num_inference_steps / strength_second_pass))
+    first_step = int(num_steps * (1 - strength_second_pass))
+    ic_light.set_inference_steps(num_steps, first_step)
+
+    # initialize the latents
+    x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)
+
+    # second pass
+    for step in ic_light.steps:
+        x = ic_light(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=condition_scale,
+        )
+
+    return ic_light.lda.latents_to_image(x)
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(TITLE)
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
+            run_button = gr.Button(value="Relight Image")
+        with gr.Column():
+            output_image = gr.Image(label="Result")
+
+    with gr.Accordion("Advanced Settings", open=True):
+        prompt = gr.Textbox(
+            label="Prompt",
+            placeholder="bright green neon light, best quality, highres",
+        )
+        neg_prompt = gr.Textbox(
+            label="Negative Prompt",
+            placeholder="worst quality, low quality, normal quality",
+        )
+        light_pref = gr.Radio(
+            choices=["None", "Left", "Right", "Top", "Bottom"],
+            label="Light direction preference",
+            value="None",
+        )
+        seed = gr.Slider(
+            label="Seed",
+            minimum=0,
+            maximum=100_000,
+            value=69_420,
+            step=1,
+        )
+        condition_scale = gr.Slider(
+            label="Condition scale",
+            minimum=0.5,
+            maximum=2,
+            value=1.25,
+            step=0.05,
+        )
+        num_inference_steps = gr.Slider(
+            label="Number of inference steps",
+            minimum=1,
+            maximum=50,
+            value=25,
+            step=1,
+        )
+        with gr.Row():
+            strength_first_pass = gr.Slider(
+                label="Strength of the first pass",
+                minimum=0.1,
+                maximum=1,
+                value=0.9,
+                step=0.1,
+            )
+            strength_second_pass = gr.Slider(
+                label="Strength of the second pass",
+                minimum=0.1,
+                maximum=1,
+                value=0.5,
+                step=0.1,
+            )
+
+    run_button.click(
+        fn=process,
+        inputs=[
+            input_image,
+            light_pref,
+            prompt,
+            neg_prompt,
+            strength_first_pass,
+            strength_second_pass,
+            condition_scale,
+            num_inference_steps,
+            seed,
+        ],
+        outputs=output_image,
+    )
+
+    gr.Examples(  # pyright: ignore[reportUnknownMemberType]
+        examples=[
+            [
+                "examples/plant.png",
+                "None",
+                "blue purple neon light, cyberpunk city background, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
+                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
+                0.9,
+                0.5,
+                1.25,
+                25,
+                69_420,
+            ],
+            [
+                "examples/plant.png",
+                "Right",
+                "blue purple neon light, cyberpunk city background, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
+                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
+                0.9,
+                0.5,
+                1.25,
+                25,
+                69_420,
+            ],
+            [
+                "examples/plant.png",
+                "Left",
+                "floor is blue ice cavern, stalactite, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
+                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
+                0.9,
+                0.5,
+                1.25,
+                25,
+                69_420,
+            ],
+        ],
+        inputs=[
+            input_image,
+            light_pref,
+            prompt,
+            neg_prompt,
+            strength_first_pass,
+            strength_second_pass,
+            condition_scale,
+            num_inference_steps,
+            seed,
+        ],
+        outputs=output_image,
+        fn=process,
+        cache_examples="lazy",  # type: ignore
+        run_on_click=False,
+    )
+
+    demo.launch()
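
The scheduling arithmetic in process() mirrors img2img-style partial denoising: a pass of strength s over N requested steps stretches the schedule to round(N / s) steps and skips the first (1 - s) fraction of them, so roughly N denoising steps actually run. A minimal sketch of that arithmetic with the Space's default values (illustrative, not part of the commit):

```python
# Illustrative: the per-pass step scheduling used by process() above.
def pass_schedule(num_inference_steps: int, strength: float) -> tuple[int, int]:
    assert 0 < strength <= 1
    num_steps = int(round(num_inference_steps / strength))  # stretched schedule
    first_step = int(num_steps * (1 - strength))  # steps skipped at the start
    return num_steps, first_step

# with the defaults (25 steps, strengths 0.9 and 0.5):
print(pass_schedule(25, 0.9))  # (28, 2)  -> 26 steps run in the first pass
print(pass_schedule(25, 0.5))  # (50, 25) -> 25 steps run in the second pass
```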
src/utils.py ADDED
@@ -0,0 +1,112 @@
+from enum import Enum, auto
+
+import torch
+from huggingface_hub import (  # pyright: ignore[reportMissingTypeStubs]
+    hf_hub_download,  # pyright: ignore[reportUnknownVariableType]
+)
+from PIL import Image
+from refiners.fluxion.utils import load_from_safetensors, tensor_to_image
+from refiners.foundationals.clip import CLIPTextEncoderL
+from refiners.foundationals.latent_diffusion import SD1UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight
+
+
+def load_ic_light(device: torch.device, dtype: torch.dtype) -> ICLight:
+    return ICLight(
+        patch_weights=load_from_safetensors(
+            path=hf_hub_download(
+                repo_id="refiners/sd15.ic_light.fc",
+                filename="model.safetensors",
+                revision="ea10b4403e97c786a98afdcbdf0e0fec794ea542",
+            ),
+        ),
+        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
+            tensors_path=hf_hub_download(
+                repo_id="refiners/sd15.realistic_vision.v5_1.unet",
+                filename="model.safetensors",
+                revision="94f74be7adfd27bee330ea1071481c0254c29989",
+            )
+        ),
+        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
+            tensors_path=hf_hub_download(
+                repo_id="refiners/sd15.realistic_vision.v5_1.text_encoder",
+                filename="model.safetensors",
+                revision="7f6fa1e870c8f197d34488e14b89e63fb8d7fd6e",
+            )
+        ),
+        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
+            tensors_path=hf_hub_download(
+                repo_id="refiners/sd15.realistic_vision.v5_1.autoencoder",
+                filename="model.safetensors",
+                revision="99f089787a6e1a852a0992da1e286a19fcbbaa50",
+            )
+        ),
+        device=device,
+        dtype=dtype,
+    )
+
+
+def resize_modulo_8(
+    image: Image.Image,
+    size: int = 768,
+    resample: Image.Resampling | None = None,
+    on_short: bool = True,
+) -> Image.Image:
+    """
+    Resize an image, respecting its aspect ratio and ensuring both sides are multiples of 8.
+
+    The `on_short` parameter determines whether the resizing is based on the shortest side.
+    """
+    assert size % 8 == 0, "Size must be a multiple of 8 because this is the latent compression factor."
+    side_size = min(image.size) if on_short else max(image.size)
+    scale = size / (side_size * 8)
+    new_size = (int(image.width * scale) * 8, int(image.height * scale) * 8)
+    return image.resize(new_size, resample=resample or Image.Resampling.LANCZOS)
+
+
+class LightingPreference(str, Enum):
+    LEFT = auto()
+    RIGHT = auto()
+    TOP = auto()
+    BOTTOM = auto()
+    NONE = auto()
+
+    def get_init_image(self, width: int, height: int, interval: tuple[float, float] = (0.0, 1.0)) -> Image.Image | None:
+        """
+        Generate an image with a linear gradient based on the lighting preference.
+
+        In the original code, the interval is always (0.0, 1.0); we added it as a parameter to make the
+        function more flexible and allow for less contrasted images with a smaller interval.
+        See https://github.com/lllyasviel/IC-Light/blob/7886874/gradio_demo.py#L242
+        """
+        start, end = interval
+        match self:
+            case LightingPreference.LEFT:
+                tensor = torch.linspace(end, start, width).repeat(1, 1, height, 1)
+            case LightingPreference.RIGHT:
+                tensor = torch.linspace(start, end, width).repeat(1, 1, height, 1)
+            case LightingPreference.TOP:
+                tensor = torch.linspace(end, start, height).repeat(1, 1, width, 1).transpose(2, 3)
+            case LightingPreference.BOTTOM:
+                tensor = torch.linspace(start, end, height).repeat(1, 1, width, 1).transpose(2, 3)
+            case LightingPreference.NONE:
+                return None
+
+        return tensor_to_image(tensor).convert("RGB")
+
+    @classmethod
+    def from_str(cls, value: str) -> "LightingPreference":
+        match value.lower():
+            case "left":
+                return LightingPreference.LEFT
+            case "right":
+                return LightingPreference.RIGHT
+            case "top":
+                return LightingPreference.TOP
+            case "bottom":
+                return LightingPreference.BOTTOM
+            case "none":
+                return LightingPreference.NONE
+            case _:
+                raise ValueError(f"Invalid lighting preference: {value}")
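
A short usage sketch for these helpers, assuming the repository root as the working directory and src/ on sys.path, as when the Space runs (illustrative, not part of the commit):

```python
# Illustrative usage of the helpers defined in src/utils.py above.
from PIL import Image

from utils import LightingPreference, resize_modulo_8

# left preference: a horizontal gradient, brightest on the left; the (0.2, 0.8)
# interval softens the contrast compared to the original (0.0, 1.0)
gradient = LightingPreference.from_str("left").get_init_image(width=768, height=512, interval=(0.2, 0.8))
assert gradient is not None and gradient.size == (768, 512)

# "none" yields no init image, which makes app.py fall back to a full-strength first pass
assert LightingPreference.from_str("none").get_init_image(width=768, height=512) is None

# resizing keeps the aspect ratio and makes both sides multiples of 8
image = Image.open("examples/plant.png").convert("RGBA")
resized = resize_modulo_8(image, size=768)
assert resized.width % 8 == 0 and resized.height % 8 == 0
```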