Gemini899 committed
Commit b1bb2b0 · verified · 1 Parent(s): 4499056

Update flux1_img2img.py

Files changed (1):
  1. flux1_img2img.py +30 -18
flux1_img2img.py CHANGED
@@ -4,6 +4,15 @@ from PIL import Image
 import sys
 import spaces
 
+def resize_image(image, max_res=512):
+    w, h = image.size
+    ratio = min(max_res / w, max_res / h)
+    if ratio < 1.0:
+        new_w = int(w * ratio)
+        new_h = int(h * ratio)
+        image = image.resize((new_w, new_h), Image.LANCZOS)
+    return image
+
 @spaces.GPU
 def process_image(
     image,
@@ -19,27 +28,32 @@ def process_image(
         print("empty input image returned")
         return None
 
-    # 1) Use float16 (T4 doesn't have native bf16 support)
-    # 2) low_cpu_mem_usage=True for more efficient loading
-    # 3) Optionally enable xFormers
+    # Try resizing input to reduce VRAM usage
+    image = resize_image(image, 512)
+
+    # Load with float16
     pipe = FluxImg2ImgPipeline.from_pretrained(
         model_id,
-        torch_dtype=torch.float16,
-        revision="fp16",  # sometimes needed if the repo has an FP16 branch
-        low_cpu_mem_usage=True
-    )
-
-    # Move to GPU
-    pipe.to("cuda")
+        torch_dtype=torch.float16
+    ).to("cuda")
 
-    # If you have xFormers installed (pip install xformers):
+    # If xFormers installed, enable memory efficient attention
     try:
         pipe.enable_xformers_memory_efficient_attention()
         print("Enabled xFormers memory efficient attention.")
     except Exception as e:
-        print("xFormers not available:", e)
+        print("Could not enable xFormers:", e)
+
+    # Enable CPU offload to reduce VRAM usage
+    # (Pick either model_cpu_offload or sequential_cpu_offload)
+    try:
+        pipe.enable_model_cpu_offload()
+    except Exception as e:
+        print("Could not enable model_cpu_offload:", e)
+
+    # Optional: enable VAE slicing
+    pipe.enable_vae_slicing()
 
-    # Create a reproducible generator
     generator = torch.Generator("cuda").manual_seed(seed)
 
     print(f"Prompt: {prompt}")
@@ -48,18 +62,16 @@ def process_image(
         image=image,
         generator=generator,
         strength=strength,
-        guidance_scale=0,  # same as your original code
+        guidance_scale=0,
         num_inference_steps=num_inference_steps,
         max_sequence_length=256
     )
 
-    # TODO: support mask if needed
     return output.images[0]
 
 if __name__ == "__main__":
-    # Usage: python img2img.py input_image.png input_mask.png output.png
     image = Image.open(sys.argv[1]).convert("RGB")
-    mask = Image.open(sys.argv[2]).convert("RGB")  # currently unused
+    mask = Image.open(sys.argv[2]).convert("RGB")  # unused
     result = process_image(image, mask)
-    if result is not None:
+    if result:
         result.save(sys.argv[3])
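
For reference, the command-line entry point is unchanged by this commit: the __main__ block still reads an input image, a mask (parsed but unused), and an output path from sys.argv, so the script is invoked along the lines of

    python flux1_img2img.py input_image.png input_mask.png output.png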
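
The memory-saving calls this commit introduces (float16 weights, xFormers attention, model CPU offload, VAE slicing) can also be exercised outside the Space. Below is a minimal standalone sketch of the same setup, not the Space's exact code: the model id "black-forest-labs/FLUX.1-schnell", the prompt, the file names, and the sampler settings are illustrative placeholders, and every enable_* call is wrapped so a missing dependency or an older diffusers build degrades gracefully.

    import torch
    from PIL import Image
    from diffusers import FluxImg2ImgPipeline

    # Placeholder model id; the Space defines its own model_id elsewhere in the script
    model_id = "black-forest-labs/FLUX.1-schnell"

    # Load in float16 and move to the GPU, mirroring the committed code
    pipe = FluxImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # Optional memory savers, same calls as in the diff above
    try:
        pipe.enable_xformers_memory_efficient_attention()  # requires xformers to be installed
    except Exception as e:
        print("Could not enable xFormers:", e)
    try:
        pipe.enable_model_cpu_offload()  # keeps idle sub-models on the CPU between steps
    except Exception as e:
        print("Could not enable model_cpu_offload:", e)
    try:
        pipe.enable_vae_slicing()  # decode latents in slices to lower peak VRAM
    except Exception as e:
        print("Could not enable VAE slicing:", e)

    # Illustrative img2img call with the same arguments the Space passes
    init = Image.open("input_image.png").convert("RGB")
    result = pipe(
        prompt="a watercolor landscape",
        image=init,
        strength=0.6,
        guidance_scale=0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cuda").manual_seed(42),
    ).images[0]
    result.save("output.png")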