Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -38,9 +38,6 @@ def encode(init_image, torch_device, ae):
|
|
38 |
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
|
39 |
init_image = init_image.unsqueeze(0)
|
40 |
init_image = init_image.to(torch_device)
|
41 |
-
print("!!!!!!!init_image!!!!!!",init_image.device)
|
42 |
-
print("!!!!!!!ae!!!!!!",next(ae.parameters()).device)
|
43 |
-
|
44 |
with torch.no_grad():
|
45 |
init_image = ae.encode(init_image.to()).to(torch.bfloat16)
|
46 |
return init_image
|
@@ -65,20 +62,22 @@ class FluxEditor:
|
|
65 |
# init all components
|
66 |
self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
|
67 |
self.clip = load_clip(self.device)
|
68 |
-
self.model = load_flow_model(self.name, device=
|
69 |
-
self.ae = load_ae(self.name, device=
|
70 |
self.t5.eval()
|
71 |
self.clip.eval()
|
72 |
self.ae.eval()
|
73 |
self.model.eval()
|
74 |
-
|
75 |
-
self.
|
76 |
-
|
77 |
-
|
|
|
78 |
|
79 |
@torch.inference_mode()
|
80 |
@spaces.GPU(duration=60)
|
81 |
def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
|
|
|
82 |
seed = None
|
83 |
# if seed == -1:
|
84 |
# seed = None
|
@@ -112,6 +111,11 @@ class FluxEditor:
|
|
112 |
t0 = time.perf_counter()
|
113 |
|
114 |
opts.seed = None
|
|
|
|
|
|
|
|
|
|
|
115 |
#############inverse#######################
|
116 |
info = {}
|
117 |
info['feature'] = {}
|
@@ -125,6 +129,12 @@ class FluxEditor:
|
|
125 |
inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
|
126 |
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
# inversion initial noise
|
129 |
with torch.no_grad():
|
130 |
z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
|
@@ -136,6 +146,12 @@ class FluxEditor:
|
|
136 |
# denoise initial noise
|
137 |
x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
# decode latents to pixel space
|
140 |
x = unpack(x.float(), opts.width, opts.height)
|
141 |
|
@@ -171,7 +187,7 @@ class FluxEditor:
|
|
171 |
exif_data[ExifTags.Base.Model] = self.name
|
172 |
if self.add_sampling_metadata:
|
173 |
exif_data[ExifTags.Base.ImageDescription] = source_prompt
|
174 |
-
|
175 |
|
176 |
|
177 |
print("End Edit")
|
@@ -226,5 +242,5 @@ if __name__ == "__main__":
|
|
226 |
parser.add_argument("--port", type=int, default=41035)
|
227 |
args = parser.parse_args()
|
228 |
|
229 |
-
demo = create_demo(
|
230 |
-
demo.launch()
|
|
|
38 |
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
|
39 |
init_image = init_image.unsqueeze(0)
|
40 |
init_image = init_image.to(torch_device)
|
|
|
|
|
|
|
41 |
with torch.no_grad():
|
42 |
init_image = ae.encode(init_image.to()).to(torch.bfloat16)
|
43 |
return init_image
|
|
|
62 |
# init all components
|
63 |
self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
|
64 |
self.clip = load_clip(self.device)
|
65 |
+
self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
|
66 |
+
self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
|
67 |
self.t5.eval()
|
68 |
self.clip.eval()
|
69 |
self.ae.eval()
|
70 |
self.model.eval()
|
71 |
+
|
72 |
+
if self.offload:
|
73 |
+
self.model.cpu()
|
74 |
+
torch.cuda.empty_cache()
|
75 |
+
self.ae.encoder.to(self.device)
|
76 |
|
77 |
@torch.inference_mode()
|
78 |
@spaces.GPU(duration=60)
|
79 |
def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
|
80 |
+
torch.cuda.empty_cache()
|
81 |
seed = None
|
82 |
# if seed == -1:
|
83 |
# seed = None
|
|
|
111 |
t0 = time.perf_counter()
|
112 |
|
113 |
opts.seed = None
|
114 |
+
if self.offload:
|
115 |
+
self.ae = self.ae.cpu()
|
116 |
+
torch.cuda.empty_cache()
|
117 |
+
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
118 |
+
|
119 |
#############inverse#######################
|
120 |
info = {}
|
121 |
info['feature'] = {}
|
|
|
129 |
inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
|
130 |
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
|
131 |
|
132 |
+
# offload TEs to CPU, load model to gpu
|
133 |
+
if self.offload:
|
134 |
+
self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
|
135 |
+
torch.cuda.empty_cache()
|
136 |
+
self.model = self.model.to(self.device)
|
137 |
+
|
138 |
# inversion initial noise
|
139 |
with torch.no_grad():
|
140 |
z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
|
|
|
146 |
# denoise initial noise
|
147 |
x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
|
148 |
|
149 |
+
# offload model, load autoencoder to gpu
|
150 |
+
if self.offload:
|
151 |
+
self.model.cpu()
|
152 |
+
torch.cuda.empty_cache()
|
153 |
+
self.ae.decoder.to(x.device)
|
154 |
+
|
155 |
# decode latents to pixel space
|
156 |
x = unpack(x.float(), opts.width, opts.height)
|
157 |
|
|
|
187 |
exif_data[ExifTags.Base.Model] = self.name
|
188 |
if self.add_sampling_metadata:
|
189 |
exif_data[ExifTags.Base.ImageDescription] = source_prompt
|
190 |
+
img.save(fn, exif=exif_data, quality=95, subsampling=0)
|
191 |
|
192 |
|
193 |
print("End Edit")
|
|
|
242 |
parser.add_argument("--port", type=int, default=41035)
|
243 |
args = parser.parse_args()
|
244 |
|
245 |
+
demo = create_demo(args.name, args.device)
|
246 |
+
demo.launch()
|