Commit 0e1ee20 · update
Parent(s): 31cffe7
app.py CHANGED
@@ -35,84 +35,20 @@ intro = """
 </h1>
 <span>[<a target="_blank" href="https://ray-1026.github.io/lightsout/">Project page</a>]</span>
 </div>
+<div style="text-align: center; margin-top: 15px; font-size: 1.2em;">
+    <strong>NOTICE</strong>: This demo is limited to cpu inference only. For better experience, please run the code locally with a GPU.
+</div>
 """

-
-#
-
-
-
-#
-
-
-
-# - All added objects or modifications must align with the logic and style of the edited input image’s overall scene.
-# ## 2. Task Type Handling Rules
-# ### 1. Add, Delete, Replace Tasks
-# - If the instruction is clear (already includes task type, target entity, position, quantity, attributes), preserve the original intent and only refine the grammar.
-# - If the description is vague, supplement with minimal but sufficient details (category, color, size, orientation, position, etc.). For example:
-# > Original: "Add an animal"
-# > Rewritten: "Add a light-gray cat in the bottom-right corner, sitting and facing the camera"
-# - Remove meaningless instructions: e.g., "Add 0 objects" should be ignored or flagged as invalid.
-# - For replacement tasks, specify "Replace Y with X" and briefly describe the key visual features of X.
-# ### 2. Text Editing Tasks
-# - All text content must be enclosed in English double quotes `" "`. Do not translate or alter the original language of the text, and do not change the capitalization.
-# - **For text replacement tasks, always use the fixed template:**
-# - `Replace "xx" to "yy"`.
-# - `Replace the xx bounding box to "yy"`.
-# - If the user does not specify text content, infer and add concise text based on the instruction and the input image’s context. For example:
-# > Original: "Add a line of text" (poster)
-# > Rewritten: "Add text \"LIMITED EDITION\" at the top center with slight shadow"
-# - Specify text position, color, and layout in a concise way.
-# ### 3. Human Editing Tasks
-# - Maintain the person’s core visual consistency (ethnicity, gender, age, hairstyle, expression, outfit, etc.).
-# - If modifying appearance (e.g., clothes, hairstyle), ensure the new element is consistent with the original style.
-# - **For expression changes, they must be natural and subtle, never exaggerated.**
-# - If deletion is not specifically emphasized, the most important subject in the original image (e.g., a person, an animal) should be preserved.
-# - For background change tasks, emphasize maintaining subject consistency at first.
-# - Example:
-# > Original: "Change the person’s hat"
-# > Rewritten: "Replace the man’s hat with a dark brown beret; keep smile, short hair, and gray jacket unchanged"
-# ### 4. Style Transformation or Enhancement Tasks
-# - If a style is specified, describe it concisely with key visual traits. For example:
-# > Original: "Disco style"
-# > Rewritten: "1970s disco: flashing lights, disco ball, mirrored walls, colorful tones"
-# - If the instruction says "use reference style" or "keep current style," analyze the input image, extract main features (color, composition, texture, lighting, art style), and integrate them concisely.
-# - **For coloring tasks, including restoring old photos, always use the fixed template:** "Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"
-# - If there are other changes, place the style description at the end.
-# ## 3. Rationality and Logic Checks
-# - Resolve contradictory instructions: e.g., "Remove all trees but keep all trees" should be logically corrected.
-# - Add missing key information: if position is unspecified, choose a reasonable area based on composition (near subject, empty space, center/edges).
-# # Output Format Example
-# ```json
-# {
-#     "Rewritten": "..."
-# }
-# """
-
-
-# def polish_prompt(prompt, img):
-#     prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
-#     success = False
-#     while not success:
-#         try:
-#             result = api(prompt, [img])
-#             # print(f"Result: {result}")
-#             # print(f"Polished Prompt: {polished_prompt}")
-#             if isinstance(result, str):
-#                 result = result.replace("```json", "")
-#                 result = result.replace("```", "")
-#                 result = json.loads(result)
-#             else:
-#                 result = json.loads(result)
-
-#             polished_prompt = result["Rewritten"]
-#             polished_prompt = polished_prompt.strip()
-#             polished_prompt = polished_prompt.replace("\n", " ")
-#             success = True
-#         except Exception as e:
-#             print(f"[Warning] Error during API call: {e}")
-#     return polished_prompt
+css = """
+#col-container {
+    margin: 0 auto;
+    max-width: 1024px;
+}
+#edit_text{
+    margin-top: -62px !important
+}
+"""


 def encode_image(pil_image):
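
The hunk above drops the commented-out `polish_prompt` helper, which wrapped an external VLM call in a retry-and-parse loop. A minimal runnable sketch of that pattern follows, with a retry cap added as a safeguard; the `api` callable and `system_prompt` argument are stand-ins, not part of this commit:

import json

MAX_RETRIES = 3  # safeguard added for the sketch; the original looped until success


def polish_prompt(prompt, img, api, system_prompt):
    """Rewrite a user instruction via an external VLM and parse the JSON reply."""
    query = f"{system_prompt}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            result = api(query, [img])
            if isinstance(result, str):
                # strip markdown fences the model may wrap around its JSON
                result = result.replace("```json", "").replace("```", "")
            parsed = json.loads(result)
            return parsed["Rewritten"].strip().replace("\n", " ")
        except Exception as e:
            print(f"[Warning] Error during API call (attempt {attempt}): {e}")
    raise RuntimeError("polish_prompt failed after all retries")
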
@@ -123,46 +59,13 @@ def encode_image(pil_image):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")


-#
-
-
-#     api_key = os.environ.get("DASH_API_KEY")
-#     if not api_key:
-#         raise EnvironmentError("DASH_API_KEY is not set")
-#     assert model in ["qwen-vl-max-latest"], f"Not implemented model {model}"
-#     sys_promot = (
-#         "you are a helpful assistant, you should provide useful answers to users."
-#     )
-#     messages = [
-#         {"role": "system", "content": sys_promot},
-#         {"role": "user", "content": []},
-#     ]
-#     for img in img_list:
-#         messages[1]["content"].append(
-#             {"image": f"data:image/png;base64,{encode_image(img)}"}
-#         )
-#     messages[1]["content"].append({"text": f"{prompt}"})
-
-#     response_format = kwargs.get("response_format", None)
-
-#     response = dashscope.MultiModalConversation.call(
-#         api_key=api_key,
-#         model=model,  # For example, use qwen-plus here. You can change the model name as needed. Model list: https://help.aliyun.com/zh/model-studio/getting-started/models
-#         messages=messages,
-#         result_format="message",
-#         response_format=response_format,
-#     )
-
-#     if response.status_code == 200:
-#         return response.output.choices[0].message.content[0]["text"]
-#     else:
-#         raise Exception(f"Failed to post: {response}")
+# --- UI Constants and Helpers ---
+MAX_SEED = np.iinfo(np.int32).max


 ## --- Model Loading --- ##
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.bfloat16
-print(f"Using device: {device}")

 # controlnet
 controlnet = ControlNetModel.from_pretrained(
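
The new `MAX_SEED = np.iinfo(np.int32).max` constant is the usual upper bound for seed handling in Gradio demos. A hypothetical helper shows how it would typically be consumed; the randomize-seed checkbox itself stays commented out in this commit:

import random

import numpy as np

MAX_SEED = np.iinfo(np.int32).max  # largest 32-bit seed value


def resolve_seed(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed when randomization is requested."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
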
@@ -171,7 +74,9 @@ controlnet = ControlNetModel.from_pretrained(

 # outpainter
 pipe = ControlNetOutpaintPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-2-inpainting",
+    "stabilityai/stable-diffusion-2-inpainting",
+    controlnet=controlnet,
+    torch_dtype=dtype,
 ).to(device)
 pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)
 pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True)
@@ -189,17 +94,20 @@ blip2 = blip2.to(device)

 # light regressor
 lsr_module = LightSourceRegressor()
-ckpt = torch.load(
+ckpt = torch.load(
+    "./weights/light_regress/model.pth", map_location="cpu" if device == "cpu" else None
+)
 lsr_module.load_state_dict(ckpt["model"])
 lsr_module.to(device)
 lsr_module.eval()

 # SIFR model
 sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device)
-sifr_model.load_state_dict(
-
-
-
+sifr_model.load_state_dict(
+    torch.load(
+        "./weights/net_g_last.pth", map_location="cpu" if device == "cpu" else None
+    )
+)


 # --- Main Inference Function (with hardcoded negative prompt) ---
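
Both checkpoint loads above now pass `map_location` so that CUDA-trained weights can be deserialized on the CPU-only Space. The same pattern in isolation, with an illustrative path:

import torch


def load_checkpoint(path: str, device: str) -> dict:
    """Load a checkpoint, remapping CUDA tensors to CPU when no GPU is available."""
    # map_location="cpu" prevents deserialization errors on hosts without CUDA;
    # passing None keeps tensors on the device they were saved from.
    return torch.load(path, map_location="cpu" if device == "cpu" else None)


# e.g. state = load_checkpoint("./weights/light_regress/model.pth", device)
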
@@ -220,7 +128,9 @@ def infer(
     Generates an image
     """
     # dataset
-    dataset = HFCustomImageLoader(
+    dataset = HFCustomImageLoader(
+        image, left_outpaint, right_outpaint, up_outpaint, down_outpaint
+    )
     data = dataset[0]

     # generator
@@ -255,9 +165,7 @@ def infer(
     pred_mask = pred_mask.cpu()
     pred_mask = pred_mask.numpy()

-    data["control_img"] = Image.fromarray(
-        (pred_mask[0, 0] * 255).astype(np.uint8)
-    )
+    data["control_img"] = Image.fromarray((pred_mask[0, 0] * 255).astype(np.uint8))

     print("Finish light source detection...")

@@ -337,21 +245,12 @@ def infer(

     print("Finish flare removal...")

-    return outpaint_result, deflare_result
+    return data["control_img"], outpaint_result, deflare_result


 # --- Examples and UI Layout ---
 examples = []

-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 1024px;
-}
-#edit_text{
-    margin-top: -62px !important
-}
-"""

 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
@@ -378,16 +277,28 @@ with gr.Blocks(css=css) as demo:

             with gr.Column():
                 left_outpaint = gr.Slider(
-                    label="Left outpaint (px)",
+                    label="Left outpaint (px)",
+                    minimum=32,
+                    maximum=128,
+                    step=32,
+                    value=64,
                 )
                 right_outpaint = gr.Slider(
-                    label="Right outpaint (px)",
+                    label="Right outpaint (px)",
+                    minimum=32,
+                    maximum=128,
+                    step=32,
+                    value=64,
                 )
                 up_outpaint = gr.Slider(
                     label="Up outpaint (px)", minimum=32, maximum=128, step=32, value=64
                 )
                 down_outpaint = gr.Slider(
-                    label="Down outpaint (px)",
+                    label="Down outpaint (px)",
+                    minimum=32,
+                    maximum=128,
+                    step=32,
+                    value=64,
                 )

                 # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
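
The four direction sliders now repeat an identical 32 to 128 range. A hypothetical factory, not part of the commit, would remove that repetition:

import gradio as gr


def outpaint_slider(label: str) -> gr.Slider:
    """Build one of the identically configured outpaint sliders."""
    return gr.Slider(label=label, minimum=32, maximum=128, step=32, value=64)


# left_outpaint = outpaint_slider("Left outpaint (px)")
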
@@ -418,6 +329,10 @@ with gr.Blocks(css=css) as demo:
                 )

         with gr.Row():
+            with gr.Column():
+                lightmask_result = gr.Image(
+                    label="Lightmask Result", show_label=True, type="pil"
+                )
             with gr.Column():
                 outpainted_result = gr.Image(
                     label="Outpainted Result", show_label=True, type="pil"
@@ -458,7 +373,7 @@ with gr.Blocks(css=css) as demo:
             up_outpaint,
             down_outpaint,
         ],
-        outputs=[outpainted_result, flarefree_result],
+        outputs=[lightmask_result, outpainted_result, flarefree_result],
     )

 if __name__ == "__main__":
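
With these changes `infer` returns three images, and they map positionally onto the three `gr.Image` components in the `outputs` list. A condensed sketch of that wiring; the button and input component names are hypothetical, since only the outputs and the four sliders appear in the diff:

# Each value returned by infer fills the output component at the same index:
#   data["control_img"] -> lightmask_result
#   outpaint_result     -> outpainted_result
#   deflare_result      -> flarefree_result
run_button.click(  # "run_button" and "input_image" are hypothetical names
    fn=infer,
    inputs=[input_image, left_outpaint, right_outpaint, up_outpaint, down_outpaint],
    outputs=[lightmask_result, outpainted_result, flarefree_result],
)
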
src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc CHANGED
Binary files a/src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc and b/src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc differ