Spaces:
Configuration error
Configuration error
wip statewrapper chagne
Browse files- app.py +2 -2
- app_backend.py +7 -19
- loaders.py +1 -0
app.py
CHANGED
@@ -19,7 +19,7 @@ from loaders import load_default
|
|
19 |
from animation import create_gif
|
20 |
from prompts import get_random_prompts
|
21 |
|
22 |
-
device = "
|
23 |
vqgan = load_default(device)
|
24 |
vqgan.eval()
|
25 |
processor = ProcessorGradientFlow(device=device)
|
@@ -62,7 +62,7 @@ class StateWrapper:
|
|
62 |
return state, *state[0].update_requant(*args, **kwargs)
|
63 |
|
64 |
with gr.Blocks(css="styles.css") as demo:
|
65 |
-
promptoptim =
|
66 |
state = gr.State([ImageState(vqgan, promptoptim)])
|
67 |
with gr.Row():
|
68 |
with gr.Column(scale=1):
|
|
|
19 |
from animation import create_gif
|
20 |
from prompts import get_random_prompts
|
21 |
|
22 |
+
device = "cpu"
|
23 |
vqgan = load_default(device)
|
24 |
vqgan.eval()
|
25 |
processor = ProcessorGradientFlow(device=device)
|
|
|
62 |
return state, *state[0].update_requant(*args, **kwargs)
|
63 |
|
64 |
with gr.Blocks(css="styles.css") as demo:
|
65 |
+
promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
|
66 |
state = gr.State([ImageState(vqgan, promptoptim)])
|
67 |
with gr.Row():
|
68 |
with gr.Column(scale=1):
|
app_backend.py
CHANGED
@@ -174,19 +174,13 @@ class ImagePromptOptimizer(nn.Module):
|
|
174 |
clip_clone = processed_img.clone()
|
175 |
clip_clone.register_hook(self.attn_masking)
|
176 |
clip_clone.retain_grad()
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
# with torch.no_grad():
|
183 |
-
# disc_logits = self.disc(transformed_img)
|
184 |
-
# disc_loss = self.disc_loss_fn(disc_logits)
|
185 |
-
# print(f"disc_loss = {disc_loss}")
|
186 |
-
# disc_loss2 = self.disc(processed_img)
|
187 |
if log:
|
188 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
189 |
-
# wandb.log({"Discriminator Loss": disc_loss})
|
190 |
wandb.log({"CLIP Loss": clip_loss})
|
191 |
clip_loss.backward(retain_graph=True)
|
192 |
perceptual_loss.backward(retain_graph=True)
|
@@ -208,14 +202,8 @@ class ImagePromptOptimizer(nn.Module):
|
|
208 |
lpips_input = processed_img.clone()
|
209 |
lpips_input.register_hook(self.attn_masking2)
|
210 |
lpips_input.retain_grad()
|
211 |
-
|
212 |
-
|
213 |
-
# with torch.no_grad():
|
214 |
-
# disc_logits = self.disc(transformed_img)
|
215 |
-
# disc_loss = self.disc_loss_fn(disc_logits)
|
216 |
-
# print(f"disc_loss = {disc_loss}")
|
217 |
-
# disc_loss2 = self.disc(processed_img)
|
218 |
-
# print(f"disc_loss2 = {disc_loss2}")
|
219 |
if log:
|
220 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
221 |
print("LPIPS loss: ", perceptual_loss)
|
|
|
174 |
clip_clone = processed_img.clone()
|
175 |
clip_clone.register_hook(self.attn_masking)
|
176 |
clip_clone.retain_grad()
|
177 |
+
with torch.autocast("cuda"):
|
178 |
+
clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_clone)
|
179 |
+
print("CLIP loss", clip_loss)
|
180 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
181 |
+
print("LPIPS loss: ", perceptual_loss)
|
|
|
|
|
|
|
|
|
|
|
182 |
if log:
|
183 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
|
|
184 |
wandb.log({"CLIP Loss": clip_loss})
|
185 |
clip_loss.backward(retain_graph=True)
|
186 |
perceptual_loss.backward(retain_graph=True)
|
|
|
202 |
lpips_input = processed_img.clone()
|
203 |
lpips_input.register_hook(self.attn_masking2)
|
204 |
lpips_input.retain_grad()
|
205 |
+
with torch.autocast("cuda"):
|
206 |
+
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
if log:
|
208 |
wandb.log({"Perceptual Loss": perceptual_loss})
|
209 |
print("LPIPS loss: ", perceptual_loss)
|
loaders.py
CHANGED
@@ -36,6 +36,7 @@ def load_default(device):
|
|
36 |
sd = torch.load("./vqgan_only.pt", map_location=device)
|
37 |
model.load_state_dict(sd, strict=True)
|
38 |
model.to(device)
|
|
|
39 |
return model
|
40 |
|
41 |
|
|
|
36 |
sd = torch.load("./vqgan_only.pt", map_location=device)
|
37 |
model.load_state_dict(sd, strict=True)
|
38 |
model.to(device)
|
39 |
+
del sd
|
40 |
return model
|
41 |
|
42 |
|