Spaces:
Running
on
A10G
Running
on
A10G
rynmurdock
commited on
Commit
β’
447c576
1
Parent(s):
e73e95e
Update app.py
Browse files
app.py
CHANGED
@@ -8,10 +8,15 @@ from sklearn.svm import LinearSVC
|
|
8 |
from sklearn import preprocessing
|
9 |
import pandas as pd
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
import random
|
12 |
import time
|
13 |
|
14 |
-
import replicate
|
15 |
import torch
|
16 |
from urllib.request import urlopen
|
17 |
|
@@ -24,11 +29,49 @@ prompt_list = [p for p in list(set(
|
|
24 |
|
25 |
start_time = time.time()
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
def next_image(embs, ys, calibrate_prompts):
|
34 |
global glob_idx
|
@@ -46,26 +89,11 @@ def next_image(embs, ys, calibrate_prompts):
|
|
46 |
print('######### Calibrating with sample prompts #########')
|
47 |
prompt = calibrate_prompts.pop(0)
|
48 |
print(prompt)
|
49 |
-
|
50 |
-
|
51 |
-
input={"prompt": prompt,}
|
52 |
-
)
|
53 |
-
prediction.wait()
|
54 |
-
output = prediction.output
|
55 |
-
|
56 |
-
# output = replicate.run(
|
57 |
-
# "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42",
|
58 |
-
# input={"prompt": prompt,}
|
59 |
-
# )
|
60 |
-
|
61 |
-
response = requests.get(output['file1'])
|
62 |
-
image = Image.open(BytesIO(response.content))
|
63 |
-
|
64 |
-
embs.append(torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0))
|
65 |
return image, embs, ys, calibrate_prompts
|
66 |
else:
|
67 |
print('######### Roaming #########')
|
68 |
-
|
69 |
# sample only as many negatives as there are positives
|
70 |
indices = range(len(ys))
|
71 |
pos_indices = [i for i in indices if ys[i] == 1]
|
@@ -93,28 +121,8 @@ def next_image(embs, ys, calibrate_prompts):
|
|
93 |
im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
|
94 |
prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
|
95 |
print(prompt)
|
96 |
-
|
97 |
-
im_emb_st = str(im_emb[0].cpu().detach().tolist())[1:-1]
|
98 |
-
|
99 |
-
prediction = deployment.predictions.create(
|
100 |
-
input={"prompt": prompt, 'im_emb': im_emb_st}
|
101 |
-
)
|
102 |
-
prediction.wait()
|
103 |
-
output = prediction.output
|
104 |
-
|
105 |
-
# output = replicate.run(
|
106 |
-
# "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42",
|
107 |
-
# input={"prompt": prompt, 'im_emb': im_emb_st}
|
108 |
-
# )
|
109 |
-
|
110 |
-
response = requests.get(output['file1'])
|
111 |
-
image = Image.open(BytesIO(response.content))
|
112 |
-
|
113 |
-
|
114 |
-
im_emb = torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0)
|
115 |
embs.append(im_emb)
|
116 |
-
|
117 |
-
torch.save(lin_class.coef_, f'./{start_time}.pt')
|
118 |
return image, embs, ys, calibrate_prompts
|
119 |
|
120 |
|
@@ -195,6 +203,6 @@ with gr.Blocks(css=css) as demo:
|
|
195 |
[b4, embs, ys, calibrate_prompts],
|
196 |
[b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
|
197 |
with gr.Row():
|
198 |
-
html = gr.HTML('''<div style='text-align:center; font-size:32'>You will
|
199 |
|
200 |
demo.launch() # Share your demo with just 1 extra parameter π
|
|
|
8 |
from sklearn import preprocessing
|
9 |
import pandas as pd
|
10 |
|
11 |
+
from diffusers import LCMScheduler
|
12 |
+
from diffusers.models import ImageProjection
|
13 |
+
from patch_sdxl import SDEmb
|
14 |
+
import torch
|
15 |
+
import spaces
|
16 |
+
|
17 |
import random
|
18 |
import time
|
19 |
|
|
|
20 |
import torch
|
21 |
from urllib.request import urlopen
|
22 |
|
|
|
29 |
|
30 |
start_time = time.time()
|
31 |
|
32 |
+
####################### Setup Model
|
33 |
+
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
34 |
+
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
|
35 |
+
pipe = SDEmb.from_pretrained(model_id, variant="fp16")
|
36 |
+
pipe.load_lora_weights(lcm_lora_id)
|
37 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
38 |
+
pipe.to(device='cuda', dtype=torch.float16)
|
39 |
+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
40 |
+
output_hidden_state = False
|
41 |
+
#######################
|
42 |
+
|
43 |
+
@spaces.GPU
|
44 |
+
def predict(
|
45 |
+
prompt,
|
46 |
+
im_emb=None,
|
47 |
+
):
|
48 |
+
"""Run a single prediction on the model"""
|
49 |
+
with torch.no_grad():
|
50 |
+
if im_emb == None:
|
51 |
+
im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
|
52 |
+
else:
|
53 |
+
im_emb = torch.tensor([float(i) for i in im_emb.split(', ')]).unsqueeze(0).to(dtype=torch.float16).to('cuda')
|
54 |
+
image = pipe(
|
55 |
+
prompt=prompt,
|
56 |
+
ip_adapter_emb=im_emb,
|
57 |
+
height=1024,
|
58 |
+
width=1024,
|
59 |
+
num_inference_steps=8,
|
60 |
+
guidance_scale=0,
|
61 |
+
).images[0]
|
62 |
+
im_emb, _ = pipe.encode_image(
|
63 |
+
image, 'cuda', 1, output_hidden_state
|
64 |
+
)
|
65 |
+
return image, im_emb.to(DEVICE)
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
|
70 |
|
71 |
+
|
72 |
+
|
73 |
+
# TODO add to state instead of shared across all
|
74 |
+
glob_idx = 0
|
75 |
|
76 |
def next_image(embs, ys, calibrate_prompts):
|
77 |
global glob_idx
|
|
|
89 |
print('######### Calibrating with sample prompts #########')
|
90 |
prompt = calibrate_prompts.pop(0)
|
91 |
print(prompt)
|
92 |
+
image, img_emb = predict(prompt)
|
93 |
+
embs.append(img_emb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
return image, embs, ys, calibrate_prompts
|
95 |
else:
|
96 |
print('######### Roaming #########')
|
|
|
97 |
# sample only as many negatives as there are positives
|
98 |
indices = range(len(ys))
|
99 |
pos_indices = [i for i in indices if ys[i] == 1]
|
|
|
121 |
im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16)
|
122 |
prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
|
123 |
print(prompt)
|
124 |
+
image, im_emb = predict(prompt, img_emb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
embs.append(im_emb)
|
|
|
|
|
126 |
return image, embs, ys, calibrate_prompts
|
127 |
|
128 |
|
|
|
203 |
[b4, embs, ys, calibrate_prompts],
|
204 |
[b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
|
205 |
with gr.Row():
|
206 |
+
html = gr.HTML('''<div style='text-align:center; font-size:32'>You will calibrate for several prompts and then roam.</ div>''')
|
207 |
|
208 |
demo.launch() # Share your demo with just 1 extra parameter π
|