RayTsai-030 committed on
Commit 0e1ee20 · 1 Parent(s): 31cffe7
app.py CHANGED
@@ -35,84 +35,20 @@ intro = """
  </h1>
  <span>[<a target="_blank" href="https://ray-1026.github.io/lightsout/">Project page</a>]</span>
  </div>
+ <div style="text-align: center; margin-top: 15px; font-size: 1.2em;">
+     <strong>NOTICE</strong>: This demo is limited to CPU inference only. For a better experience, please run the code locally with a GPU.
+ </div>
  """
 
-
- # SYSTEM_PROMPT = """
- # # Edit Instruction Rewriter
- # You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
- # Please strictly follow the rewriting rules below:
- # ## 1. General Principles
- # - Keep the rewritten prompt **concise**. Avoid overly long sentences and reduce unnecessary descriptive language.
- # - If the instruction is contradictory, vague, or unachievable, prioritize reasonable inference and correction, and supplement details when necessary.
- # - Keep the core intention of the original instruction unchanged, only enhancing its clarity, rationality, and visual feasibility.
- # - All added objects or modifications must align with the logic and style of the edited input image’s overall scene.
- # ## 2. Task Type Handling Rules
- # ### 1. Add, Delete, Replace Tasks
- # - If the instruction is clear (already includes task type, target entity, position, quantity, attributes), preserve the original intent and only refine the grammar.
- # - If the description is vague, supplement with minimal but sufficient details (category, color, size, orientation, position, etc.). For example:
- #     > Original: "Add an animal"
- #     > Rewritten: "Add a light-gray cat in the bottom-right corner, sitting and facing the camera"
- # - Remove meaningless instructions: e.g., "Add 0 objects" should be ignored or flagged as invalid.
- # - For replacement tasks, specify "Replace Y with X" and briefly describe the key visual features of X.
- # ### 2. Text Editing Tasks
- # - All text content must be enclosed in English double quotes `" "`. Do not translate or alter the original language of the text, and do not change the capitalization.
- # - **For text replacement tasks, always use the fixed template:**
- #     - `Replace "xx" to "yy"`.
- #     - `Replace the xx bounding box to "yy"`.
- # - If the user does not specify text content, infer and add concise text based on the instruction and the input image’s context. For example:
- #     > Original: "Add a line of text" (poster)
- #     > Rewritten: "Add text \"LIMITED EDITION\" at the top center with slight shadow"
- # - Specify text position, color, and layout in a concise way.
- # ### 3. Human Editing Tasks
- # - Maintain the person’s core visual consistency (ethnicity, gender, age, hairstyle, expression, outfit, etc.).
- # - If modifying appearance (e.g., clothes, hairstyle), ensure the new element is consistent with the original style.
- # - **For expression changes, they must be natural and subtle, never exaggerated.**
- # - If deletion is not specifically emphasized, the most important subject in the original image (e.g., a person, an animal) should be preserved.
- # - For background change tasks, emphasize maintaining subject consistency at first.
- # - Example:
- #     > Original: "Change the person’s hat"
- #     > Rewritten: "Replace the man’s hat with a dark brown beret; keep smile, short hair, and gray jacket unchanged"
- # ### 4. Style Transformation or Enhancement Tasks
- # - If a style is specified, describe it concisely with key visual traits. For example:
- #     > Original: "Disco style"
- #     > Rewritten: "1970s disco: flashing lights, disco ball, mirrored walls, colorful tones"
- # - If the instruction says "use reference style" or "keep current style," analyze the input image, extract main features (color, composition, texture, lighting, art style), and integrate them concisely.
- # - **For coloring tasks, including restoring old photos, always use the fixed template:** "Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"
- # - If there are other changes, place the style description at the end.
- # ## 3. Rationality and Logic Checks
- # - Resolve contradictory instructions: e.g., "Remove all trees but keep all trees" should be logically corrected.
- # - Add missing key information: if position is unspecified, choose a reasonable area based on composition (near subject, empty space, center/edges).
- # # Output Format Example
- # ```json
- # {
- #     "Rewritten": "..."
- # }
- # """
-
-
- # def polish_prompt(prompt, img):
- #     prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
- #     success = False
- #     while not success:
- #         try:
- #             result = api(prompt, [img])
- #             # print(f"Result: {result}")
- #             # print(f"Polished Prompt: {polished_prompt}")
- #             if isinstance(result, str):
- #                 result = result.replace("```json", "")
- #                 result = result.replace("```", "")
- #                 result = json.loads(result)
- #             else:
- #                 result = json.loads(result)
-
- #             polished_prompt = result["Rewritten"]
- #             polished_prompt = polished_prompt.strip()
- #             polished_prompt = polished_prompt.replace("\n", " ")
- #             success = True
- #         except Exception as e:
- #             print(f"[Warning] Error during API call: {e}")
- #     return polished_prompt
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1024px;
+ }
+ #edit_text{
+     margin-top: -62px !important
+ }
+ """
 
 
  def encode_image(pil_image):
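Only the `return` line of `encode_image` appears in the hunk context below; for orientation, a minimal sketch of what the elided body presumably looks like (the `buffered` name comes from the diff itself, the PNG format is an assumption):

```python
import base64
from io import BytesIO

def encode_image(pil_image):
    # Serialize the PIL image into an in-memory buffer, then base64-encode it.
    # PNG is an assumption; the diff only shows the return statement.
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
```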
@@ -123,46 +59,13 @@ def encode_image(pil_image):
      return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
- # def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
- #     import dashscope
-
- #     api_key = os.environ.get("DASH_API_KEY")
- #     if not api_key:
- #         raise EnvironmentError("DASH_API_KEY is not set")
- #     assert model in ["qwen-vl-max-latest"], f"Not implemented model {model}"
- #     sys_promot = (
- #         "you are a helpful assistant, you should provide useful answers to users."
- #     )
- #     messages = [
- #         {"role": "system", "content": sys_promot},
- #         {"role": "user", "content": []},
- #     ]
- #     for img in img_list:
- #         messages[1]["content"].append(
- #             {"image": f"data:image/png;base64,{encode_image(img)}"}
- #         )
- #     messages[1]["content"].append({"text": f"{prompt}"})
-
- #     response_format = kwargs.get("response_format", None)
-
- #     response = dashscope.MultiModalConversation.call(
- #         api_key=api_key,
- #         model=model,  # For example, use qwen-plus here. You can change the model name as needed. Model list: https://help.aliyun.com/zh/model-studio/getting-started/models
- #         messages=messages,
- #         result_format="message",
- #         response_format=response_format,
- #     )
-
- #     if response.status_code == 200:
- #         return response.output.choices[0].message.content[0]["text"]
- #     else:
- #         raise Exception(f"Failed to post: {response}")
+ # --- UI Constants and Helpers ---
+ MAX_SEED = np.iinfo(np.int32).max
 
 
  ## --- Model Loading --- ##
  device = "cuda" if torch.cuda.is_available() else "cpu"
  dtype = torch.bfloat16
- print(f"Using device: {device}")
 
  # controlnet
  controlnet = ControlNetModel.from_pretrained(
@@ -171,7 +74,9 @@ controlnet = ControlNetModel.from_pretrained(
 
  # outpainter
  pipe = ControlNetOutpaintPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-2-inpainting", controlnet=controlnet, torch_dtype=dtype
+     "stabilityai/stable-diffusion-2-inpainting",
+     controlnet=controlnet,
+     torch_dtype=dtype,
  ).to(device)
  pipe.scheduler = CustomScheduler.from_config(pipe.scheduler.config)
  pipe.unet.load_attn_procs("./weights/light_outpaint_lora", use_safetensors=True)
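One note on the model-loading hunks: `dtype` stays fixed at `torch.bfloat16` for every device, while the new page notice says the Space runs CPU-only. A common defensive pattern, offered here as a suggestion rather than what this commit does, is to fall back to float32 on CPU:

```python
import torch

# bfloat16 kernels can be slow or unsupported on some CPUs; float32 is the
# safe default there. This sketch is an alternative, not the commit's code.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
```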
@@ -189,17 +94,20 @@ blip2 = blip2.to(device)
 
  # light regressor
  lsr_module = LightSourceRegressor()
- ckpt = torch.load("./weights/light_regress/model.pth", map_location="cpu" if device=="cpu" else None)
+ ckpt = torch.load(
+     "./weights/light_regress/model.pth", map_location="cpu" if device == "cpu" else None
+ )
  lsr_module.load_state_dict(ckpt["model"])
  lsr_module.to(device)
  lsr_module.eval()
 
  # SIFR model
  sifr_model = Uformer(img_size=512, img_ch=3, output_ch=6).to(device)
- sifr_model.load_state_dict(torch.load("./weights/net_g_last.pth", map_location="cpu" if device=="cpu" else None))
-
- # --- UI Constants and Helpers ---
- MAX_SEED = np.iinfo(np.int32).max
+ sifr_model.load_state_dict(
+     torch.load(
+         "./weights/net_g_last.pth", map_location="cpu" if device == "cpu" else None
+     )
+ )
 
 
  # --- Main Inference Function (with hardcoded negative prompt) ---
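The `map_location="cpu" if device == "cpu" else None` expression reflowed above can also be written by passing the target device directly. A minimal equivalent sketch (the behavior differs only in that CUDA storages are mapped to the current device rather than the one recorded in the checkpoint):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# map_location accepts a device string, so one call covers both the CPU and
# CUDA cases with no conditional.
ckpt = torch.load("./weights/light_regress/model.pth", map_location=device)
```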
@@ -220,7 +128,9 @@ def infer(
      Generates an image
      """
      # dataset
-     dataset = HFCustomImageLoader(image, left_outpaint, right_outpaint, up_outpaint, down_outpaint)
+     dataset = HFCustomImageLoader(
+         image, left_outpaint, right_outpaint, up_outpaint, down_outpaint
+     )
      data = dataset[0]
 
      # generator
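The hunk's trailing context introduces the generator setup, whose body the diff elides. The usual diffusers pattern combines the `MAX_SEED` constant moved by this commit with a seeded `torch.Generator`; a sketch, where the `make_generator` helper is hypothetical and not part of this repo:

```python
import numpy as np
import torch

MAX_SEED = np.iinfo(np.int32).max  # largest value usable as an int32 seed

def make_generator(seed: int, randomize: bool, device: str) -> torch.Generator:
    # Hypothetical helper: optionally draw a fresh seed, then build the
    # generator that diffusers pipelines accept for reproducible sampling.
    if randomize:
        seed = int(np.random.randint(0, MAX_SEED))
    return torch.Generator(device=device).manual_seed(seed)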
@@ -255,9 +165,7 @@
      pred_mask = pred_mask.cpu()
      pred_mask = pred_mask.numpy()
 
-     data["control_img"] = Image.fromarray(
-         (pred_mask[0, 0] * 255).astype(np.uint8)
-     )
+     data["control_img"] = Image.fromarray((pred_mask[0, 0] * 255).astype(np.uint8))
 
      print("Finish light source detection...")
 
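The reflowed one-liner above implies `pred_mask` is a (batch, channel, H, W) float array in [0, 1]; in isolation the conversion to an 8-bit PIL control image looks like this (the shapes are assumptions inferred from the `[0, 0]` indexing):

```python
import numpy as np
from PIL import Image

# Stand-in mask with the assumed layout: one batch item, one channel.
pred_mask = np.random.rand(1, 1, 512, 512)

# Scale [0, 1] floats to [0, 255] uint8 and wrap as a grayscale PIL image.
control_img = Image.fromarray((pred_mask[0, 0] * 255).astype(np.uint8))
```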
@@ -337,21 +245,12 @@
 
      print("Finish flare removal...")
 
-     return outpaint_result, deflare_result
+     return data["control_img"], outpaint_result, deflare_result
 
 
  # --- Examples and UI Layout ---
  examples = []
 
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 1024px;
- }
- #edit_text{
-     margin-top: -62px !important
- }
- """
 
  with gr.Blocks(css=css) as demo:
      with gr.Column(elem_id="col-container"):
@@ -378,16 +277,28 @@ with gr.Blocks(css=css) as demo:
 
          with gr.Column():
              left_outpaint = gr.Slider(
-                 label="Left outpaint (px)", minimum=32, maximum=128, step=32, value=64
+                 label="Left outpaint (px)",
+                 minimum=32,
+                 maximum=128,
+                 step=32,
+                 value=64,
              )
              right_outpaint = gr.Slider(
-                 label="Right outpaint (px)", minimum=32, maximum=128, step=32, value=64
+                 label="Right outpaint (px)",
+                 minimum=32,
+                 maximum=128,
+                 step=32,
+                 value=64,
              )
              up_outpaint = gr.Slider(
                  label="Up outpaint (px)", minimum=32, maximum=128, step=32, value=64
              )
              down_outpaint = gr.Slider(
-                 label="Down outpaint (px)", minimum=32, maximum=128, step=32, value=64
+                 label="Down outpaint (px)",
+                 minimum=32,
+                 maximum=128,
+                 step=32,
+                 value=64,
              )
 
          # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
@@ -418,6 +329,10 @@ with gr.Blocks(css=css) as demo:
          )
 
          with gr.Row():
+             with gr.Column():
+                 lightmask_result = gr.Image(
+                     label="Lightmask Result", show_label=True, type="pil"
+                 )
              with gr.Column():
                  outpainted_result = gr.Image(
                      label="Outpainted Result", show_label=True, type="pil"
@@ -458,7 +373,7 @@ with gr.Blocks(css=css) as demo:
              up_outpaint,
              down_outpaint,
          ],
-         outputs=[outpainted_result, flarefree_result],
+         outputs=[lightmask_result, outpainted_result, flarefree_result],
      )
 
  if __name__ == "__main__":
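Since `infer` now returns three values, the event wiring must list three output components in the same order. A self-contained toy sketch of that contract (the component names besides the sliders and result images, the `run` button, and the stand-in `infer` body are assumptions; the diff shows only the tails of the inputs and outputs lists):

```python
import gradio as gr

def infer(image, left, right, up, down):
    # Stand-in for the real infer(): (light mask, outpainted, flare-free).
    return image, image, image

with gr.Blocks() as demo:
    input_image = gr.Image(type="pil")  # assumed name of the source component
    sliders = [gr.Slider(32, 128, value=64, step=32, label=l)
               for l in ("Left", "Right", "Up", "Down")]
    lightmask_result = gr.Image(label="Lightmask Result", type="pil")
    outpainted_result = gr.Image(label="Outpainted Result", type="pil")
    flarefree_result = gr.Image(label="Flare-free Result", type="pil")
    run = gr.Button("Run")
    # Three return values map one-to-one onto three outputs, in order.
    run.click(fn=infer, inputs=[input_image, *sliders],
              outputs=[lightmask_result, outpainted_result, flarefree_result])
```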
 
src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc CHANGED
Binary files a/src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc and b/src/pipelines/__pycache__/pipeline_stable_diffusion_outpaint.cpython-39.pyc differ