wjs0725 committed on
Commit
1e3cd91
1 Parent(s): 74a9853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -12
app.py CHANGED
@@ -38,9 +38,6 @@ def encode(init_image, torch_device, ae):
38
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
39
  init_image = init_image.unsqueeze(0)
40
  init_image = init_image.to(torch_device)
41
- print("!!!!!!!init_image!!!!!!",init_image.device)
42
- print("!!!!!!!ae!!!!!!",next(ae.parameters()).device)
43
-
44
  with torch.no_grad():
45
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
46
  return init_image
@@ -65,20 +62,22 @@ class FluxEditor:
65
  # init all components
66
  self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
67
  self.clip = load_clip(self.device)
68
- self.model = load_flow_model(self.name, device='cuda')
69
- self.ae = load_ae(self.name, device='cuda')
70
  self.t5.eval()
71
  self.clip.eval()
72
  self.ae.eval()
73
  self.model.eval()
74
- self.t5.cuda()
75
- self.clip.cuda()
76
- self.ae.cuda()
77
- self.model.cuda()
 
78
 
79
  @torch.inference_mode()
80
  @spaces.GPU(duration=60)
81
  def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
 
82
  seed = None
83
  # if seed == -1:
84
  # seed = None
@@ -112,6 +111,11 @@ class FluxEditor:
112
  t0 = time.perf_counter()
113
 
114
  opts.seed = None
 
 
 
 
 
115
  #############inverse#######################
116
  info = {}
117
  info['feature'] = {}
@@ -125,6 +129,12 @@ class FluxEditor:
125
  inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
126
  timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
127
 
 
 
 
 
 
 
128
  # inversion initial noise
129
  with torch.no_grad():
130
  z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
@@ -136,6 +146,12 @@ class FluxEditor:
136
  # denoise initial noise
137
  x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
138
 
 
 
 
 
 
 
139
  # decode latents to pixel space
140
  x = unpack(x.float(), opts.width, opts.height)
141
 
@@ -171,7 +187,7 @@ class FluxEditor:
171
  exif_data[ExifTags.Base.Model] = self.name
172
  if self.add_sampling_metadata:
173
  exif_data[ExifTags.Base.ImageDescription] = source_prompt
174
- # img.save(fn, exif=exif_data, quality=95, subsampling=0)
175
 
176
 
177
  print("End Edit")
@@ -226,5 +242,5 @@ if __name__ == "__main__":
226
  parser.add_argument("--port", type=int, default=41035)
227
  args = parser.parse_args()
228
 
229
- demo = create_demo("flux-dev", "cuda", False)
230
- demo.launch()
 
38
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
39
  init_image = init_image.unsqueeze(0)
40
  init_image = init_image.to(torch_device)
 
 
 
41
  with torch.no_grad():
42
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
43
  return init_image
 
62
  # init all components
63
  self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
64
  self.clip = load_clip(self.device)
65
+ self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
66
+ self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
67
  self.t5.eval()
68
  self.clip.eval()
69
  self.ae.eval()
70
  self.model.eval()
71
+
72
+ if self.offload:
73
+ self.model.cpu()
74
+ torch.cuda.empty_cache()
75
+ self.ae.encoder.to(self.device)
76
 
77
  @torch.inference_mode()
78
  @spaces.GPU(duration=60)
79
  def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
80
+ torch.cuda.empty_cache()
81
  seed = None
82
  # if seed == -1:
83
  # seed = None
 
111
  t0 = time.perf_counter()
112
 
113
  opts.seed = None
114
+ if self.offload:
115
+ self.ae = self.ae.cpu()
116
+ torch.cuda.empty_cache()
117
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
118
+
119
  #############inverse#######################
120
  info = {}
121
  info['feature'] = {}
 
129
  inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
130
  timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
131
 
132
+ # offload TEs to CPU, load model to gpu
133
+ if self.offload:
134
+ self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
135
+ torch.cuda.empty_cache()
136
+ self.model = self.model.to(self.device)
137
+
138
  # inversion initial noise
139
  with torch.no_grad():
140
  z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
 
146
  # denoise initial noise
147
  x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
148
 
149
+ # offload model, load autoencoder to gpu
150
+ if self.offload:
151
+ self.model.cpu()
152
+ torch.cuda.empty_cache()
153
+ self.ae.decoder.to(x.device)
154
+
155
  # decode latents to pixel space
156
  x = unpack(x.float(), opts.width, opts.height)
157
 
 
187
  exif_data[ExifTags.Base.Model] = self.name
188
  if self.add_sampling_metadata:
189
  exif_data[ExifTags.Base.ImageDescription] = source_prompt
190
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
191
 
192
 
193
  print("End Edit")
 
242
  parser.add_argument("--port", type=int, default=41035)
243
  args = parser.parse_args()
244
 
245
+ demo = create_demo(args.name, args.device)
246
+ demo.launch()