Spaces: rynmurdock
Runtime error

rynmurdock committed • Commit 55d6a83
1 Parent(s): a581500
match multimodal

Files changed:
- __pycache__/safety_checker_improved.cpython-310.pyc +0 -0
- app.py +144 -65
- eigth.mp4 +0 -0
- ninth.mp4 +0 -0
- seventh.mp4 +0 -0
- tenth.mp4 +0 -0
__pycache__/safety_checker_improved.cpython-310.pyc
DELETED
Binary file (1.38 kB)
app.py
CHANGED
@@ -15,14 +15,12 @@ import matplotlib.pyplot as plt
 import matplotlib
 import logging
 
-from sklearn.linear_model import Ridge
 
 import os
 import imageio
 import gradio as gr
 import numpy as np
 from sklearn.svm import SVC
-from sklearn.inspection import permutation_importance
 from sklearn import preprocessing
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
@@ -39,14 +37,13 @@ torch.set_grad_enabled(False)
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
-prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
+prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
 
 import spaces
 start_time = time.time()
 
 ####################### Setup Model
-from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
-utils.logging.disable_progress_bar
+from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
 from transformers import CLIPTextModel
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
@@ -54,6 +51,7 @@ from PIL import Image
 from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
+import torchvision
 
 def write_video(file_name, images, fps=17):
     container = av.open(file_name, mode="w")
@@ -92,6 +90,9 @@ device_map='cuda')
 # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
 # vae = compile_unet(vae, config=config)
 
+#finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
+#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
+#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
 
 
 unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
@@ -99,7 +100,8 @@ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='
 device_map='cpu').to(dtype)
 
 adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
-pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
+pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
+                                           unet=unet, text_encoder=text_encoder)
 pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
 pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
 pipe.set_adapters(["lcm-lora"], [.9])
@@ -114,7 +116,7 @@ pipe.fuse_lora()
 
 pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
 # This IP adapter improves outputs substantially.
-pipe.set_ip_adapter_scale(.
+pipe.set_ip_adapter_scale(.6)
 pipe.unet.fuse_qkv_projections()
 #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
 
@@ -122,21 +124,71 @@ pipe.to(device=DEVICE)
 #pipe.unet = torch.compile(pipe.unet)
 #pipe.vae = torch.compile(pipe.vae)
 
-@spaces.GPU()
-def generate_gpu(in_im_embs):
-    in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
-    output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
-    im_emb, _ = pipe.encode_image(
-        output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
-    )
-    im_emb = im_emb.detach().to('cpu').to(torch.float32)
-    return output, im_emb
 
 
+#############################################################
+
+from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
+processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
 
-
-
-
+
+
+def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
+    inputs_embeds = pali.get_input_embeddings()(input_ids)
+    selected_image_feature = image_outputs.to(dtype).to(device)
+    image_features = pali.multi_modal_projector(selected_image_feature)
+
+    if cache_position is None:
+        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+    inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
+        image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
+    )
+    return inputs_embeds
 
+
+
+def generate_pali(user_emb):
+    prompt = 'caption en'
+    model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
+    # we need to get im_embs taken in here.
+    input_len = model_inputs["input_ids"].shape[-1]
+    input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
+                                  model_inputs["input_ids"].to(device),
+                                  model_inputs["attention_mask"].to(device))
+
+    generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
+    decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    return decoded
+
+
+
+
+#############################################################
+
+
+
+@spaces.GPU()
+def generate_gpu(in_im_embs, prompt='the scene'):
+    with torch.no_grad():
+        in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
+        output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
+        im_emb, _ = pipe.encode_image(
+            output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
+        )
+        im_emb = im_emb.detach().to('cpu').to(torch.float32)
+        im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
+        im = torch.nn.functional.interpolate(im, (224, 224))
+        im = (im - .5) * 2
+        gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
+        return output, im_emb, gemb
+
+
+def generate(in_im_embs, prompt='the scene'):
+    output, im_emb, gemb = generate_gpu(in_im_embs, prompt)
+    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
+    print(prompt)
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
 
@@ -149,19 +201,19 @@ def generate(in_im_embs):
     output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
 
     write_video(path, output.frames[0])
-    return path, im_emb
+    return path, im_emb, gemb
 
 
 #######################
 
 def get_user_emb(embs, ys):
     # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
+
     if len(list(ys)) <= 7:
-        aways = [.01*torch.
+        aways = [.01*torch.randn_like(embs[0]) for i in range(3)]
         embs += aways
         awal = [0 for i in range(3)]
         ys += awal
-        print('Fixing only one feedback class available.\n')
 
     indices = list(range(len(embs)))
     # sample only as many negatives as there are positives
@@ -176,21 +228,20 @@ def get_user_emb(embs, ys):
     # this ends up adding a rating but losing an embedding, it seems.
     # let's take off a rating if so to continue without indexing errors.
     if len(ys) > len(embs):
+        print('ys are longer than embs; popping latest rating')
        ys.pop(-1)
 
     feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
     #scaler = preprocessing.StandardScaler().fit(feature_embs)
     #feature_embs = scaler.transform(feature_embs)
-
+    chosen_y = np.array([ys[i] for i in indices])
 
     if feature_embs.norm() != 0:
         feature_embs = feature_embs / feature_embs.norm()
 
-    chosen_y = np.array([ys[i] for i in indices])
-
     #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
-    lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
-    coef_ = torch.tensor(lin_class.coef_, dtype=torch.
+    lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs.squeeze(), chosen_y)
+    coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
     coef_ = coef_ / coef_.abs().max() * 3
 
     w = 1# if len(embs) % 2 == 0 else 0
@@ -212,7 +263,8 @@ def pluck_img(user_id, user_emb):
             best_sim = sim
             best_row = i[1]
     img = best_row['paths']
-
+    text = best_row.get('text', '')
+    return img, text
 
 
 def background_next_image():
@@ -236,39 +288,48 @@ def background_next_image():
         unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
 
-        # we pop previous ratings if there are >
-        if len(rated_from_user) >=
+        # we pop previous ratings if there are > n
+        if len(rated_from_user) >= 15:
            oldest = rated_from_user.iloc[0]['paths']
            prevs_df = prevs_df[prevs_df['paths'] != oldest]
-        # we don't compute more after
+        # we don't compute more after n are in the queue for them
        if len(unrated_from_user) >= 10:
            continue
 
-        if len(rated_rows) <
+        if len(rated_rows) < 5:
            continue
 
-        embs, ys = pluck_embs_ys(uid)
+        embs, ys, gembs = pluck_embs_ys(uid)
 
        user_emb = get_user_emb(embs, ys)
-
+
+        if len(gembs) > 4:
+            user_gem = get_user_emb(gembs, ys) / 4 # TODO scale this correctly; matplotlib, etc.
+            text = generate_pali(user_gem)
+        else:
+            text = generate_pali(torch.zeros(1, 1152))
+        img, embs, new_gem = generate(user_emb, text)
+
        if img:
-            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
+            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
            tmp_df['paths'] = [img]
            tmp_df['embeddings'] = [embs]
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
+            tmp_df['text'] = [text]
+            tmp_df['gemb'] = [new_gem]
            prevs_df = pd.concat((prevs_df, tmp_df))
-
            # we can free up storage by deleting the image
-            if len(prevs_df) >
-
-
-
-
-
-
-            # only keep
-            prevs_df = prevs_df[prevs_df[
+            if len(prevs_df) > 500:
+                oldest_path = prevs_df.iloc[6]['paths']
+                if os.path.isfile(oldest_path):
+                    os.remove(oldest_path)
+                else:
+                    # If it fails, inform the user.
+                    print("Error: %s file not found" % oldest_path)
+                # only keep 50 images & embeddings & ips, then remove oldest besides calibrating
+                prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
+
 
 def pluck_embs_ys(user_id):
     rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
@@ -281,21 +342,21 @@ def pluck_embs_ys(user_id):
 
     embs = rated_rows['embeddings'].to_list()
     ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
-
+    gembs = rated_rows['gemb'].to_list()
+    return embs, ys, gembs
 
 def next_image(calibrate_prompts, user_id):
-
     with torch.no_grad():
         if len(calibrate_prompts) > 0:
             cal_video = calibrate_prompts.pop(0)
             image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
 
-            return image, calibrate_prompts
+            return image, calibrate_prompts, ''
         else:
-            embs, ys = pluck_embs_ys(user_id)
+            embs, ys, gembs = pluck_embs_ys(user_id)
             user_emb = get_user_emb(embs, ys)
-            image = pluck_img(user_id, user_emb)
-            return image, calibrate_prompts
+            image, text = pluck_img(user_id, user_emb)
+            return image, calibrate_prompts, text
 
 
 
@@ -307,7 +368,7 @@ def next_image(calibrate_prompts, user_id):
 
 def start(_, calibrate_prompts, user_id, request: gr.Request):
     user_id = int(str(time.time())[-7:].replace('.', ''))
-    image, calibrate_prompts = next_image(calibrate_prompts, user_id)
+    image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
     return [
         gr.Button(value='Like (L)', interactive=True),
         gr.Button(value='Neither (Space)', interactive=True, visible=False),
@@ -326,14 +387,15 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
     if choice == 'Like (L)':
         choice = 1
     elif choice == 'Neither (Space)':
-        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
-        return img, calibrate_prompts
+        img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
+        return img, calibrate_prompts, text
     else:
         choice = 0
 
     # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
     # TODO skip allowing rating & just continue
     if img == None:
+        print('NSFW -- choice is disliked')
        choice = 0
 
     row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
@@ -341,8 +403,8 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
     if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
         prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
     prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
-    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
-    return img, calibrate_prompts
+    img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
+    return img, calibrate_prompts, text
 
 css = '''.gradio-container{max-width: 700px !important}
 #description{text-align: center}
@@ -426,6 +488,8 @@ Explore the latent space without text prompts based on your preferences. Learn m
             elem_id="video_output"
         )
         img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
+        with gr.Row():
+            text = gr.Textbox(interactive=False, visible=True, label='Text')
        with gr.Row(equal_height=True):
            b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
            b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
@@ -433,17 +497,17 @@ Explore the latent space without text prompts based on your preferences. Learn m
        b1.click(
            choose,
            [img, b1, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        b2.click(
            choose,
            [img, b2, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        b3.click(
            choose,
            [img, b3, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        with gr.Row():
            b4 = gr.Button(value='Start')
@@ -464,20 +528,28 @@ log = logging.getLogger('log_here')
 log.setLevel(logging.ERROR)
 
 scheduler = BackgroundScheduler()
-scheduler.add_job(func=background_next_image, trigger="interval", seconds=.
+scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
 scheduler.start()
 
 #thread = threading.Thread(target=background_next_image,)
 #thread.start()
 
+# TODO shouldn't call this before gradio launch, yeah?
 @spaces.GPU()
 def encode_space(x):
     im_emb, _ = pipe.encode_image(
         image, DEVICE, 1, output_hidden_state
     )
-
+
+    im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
+    im = torch.nn.functional.interpolate(im, (224, 224))
+    im = (im - .5) * 2
+    gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
+
+    return im_emb.detach().to('cpu').to(torch.float32), gemb
 
-# prep our calibration
+# prep our calibration videos
 for im in [
     './first.mp4',
     './second.mp4',
@@ -485,16 +557,23 @@ for im in [
     './fourth.mp4',
     './fifth.mp4',
     './sixth.mp4',
+    './seventh.mp4',
+    './eigth.mp4',
+    './ninth.mp4',
+    './tenth.mp4',
 ]:
-    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
+    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
     tmp_df['paths'] = [im]
     image = list(imageio.imiter(im))
     image = image[len(image)//2]
-    im_emb = encode_space(image)
+    im_emb, gemb = encode_space(image)
 
     tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
+    tmp_df['gemb'] = [gemb.detach().to('cpu')]
     tmp_df['user:rating'] = [{' ': ' '}]
     prevs_df = pd.concat((prevs_df, tmp_df))
 
 
-demo.launch(share=True)
+demo.launch(share=True, server_port=8443)
+
+
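The core of get_user_emb above is a linear SVC fit on (embedding, rating) pairs, with the fitted coefficient vector reused as the "user embedding" that steers retrieval and generation. A self-contained toy illustration of that step; the data and the 16-dim embedding size here are synthetic placeholders, not the app's real values:

import numpy as np
import torch
from sklearn.svm import SVC

emb_dim = 16                                    # placeholder width, not the app's real embedding size
feats = torch.randn(10, emb_dim)                # toy item embeddings
ys = np.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 0])   # 1 = liked, 0 = disliked

feats = feats / feats.norm()                    # same global normalization as get_user_emb
clf = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feats, ys)
user_emb = torch.tensor(clf.coef_, dtype=torch.float32)
user_emb = user_emb / user_emb.abs().max() * 3  # rescaled direction pointing toward "liked"
print(user_emb.shape)                           # (1, emb_dim)

In this commit the same fit is reused on the PaliGemma vision embeddings ('gemb'), so the learned direction can also be captioned by generate_pali and fed back in as the text prompt for the next clip.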
eigth.mp4
ADDED
Binary file (47.7 kB)

ninth.mp4
ADDED
Binary file (255 kB)

seventh.mp4
ADDED
Binary file (50 kB)

tenth.mp4
ADDED
Binary file (129 kB)
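For reference, the 'gemb' values stored in prevs_df come from PaliGemma's SigLIP vision tower: each frame is resized to 224x224, rescaled from [0, 1] to [-1, 1], and the resulting patch tokens are mean-pooled. A minimal sketch of that preprocessing as a standalone helper, assuming a loaded `pali` model and the `dtype`/`device` used in app.py (the function name is the editor's, not the app's):

import torch
import torchvision

def frame_to_gemb(frame, pali, device='cuda', dtype=torch.float16):
    # ToTensor yields [3, H, W] in [0, 1]; add a batch dim, then match SigLIP's 224x224, [-1, 1] input.
    im = torchvision.transforms.ToTensor()(frame).unsqueeze(0)
    im = torch.nn.functional.interpolate(im, (224, 224))
    im = (im - .5) * 2
    tokens = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state
    return tokens.detach().to('cpu').to(torch.float32).mean(1)  # [1, 1152] pooled embedding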