rynmurdock committed on
Commit
55d6a83
1 Parent(s): a581500

match multimodal

__pycache__/safety_checker_improved.cpython-310.pyc DELETED
Binary file (1.38 kB)
 
app.py CHANGED
@@ -15,14 +15,12 @@ import matplotlib.pyplot as plt
 import matplotlib
 import logging
 
-from sklearn.linear_model import Ridge
 
 import os
 import imageio
 import gradio as gr
 import numpy as np
 from sklearn.svm import SVC
-from sklearn.inspection import permutation_importance
 from sklearn import preprocessing
 import pandas as pd
 from apscheduler.schedulers.background import BackgroundScheduler
@@ -39,14 +37,13 @@ torch.set_grad_enabled(False)
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
-prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
+prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id', 'text', 'gemb'])
 
 import spaces
 start_time = time.time()
 
 ####################### Setup Model
-from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, utils
-utils.logging.disable_progress_bar
+from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
 from transformers import CLIPTextModel
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
@@ -54,6 +51,7 @@ from PIL import Image
 from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
+import torchvision
 
 def write_video(file_name, images, fps=17):
     container = av.open(file_name, mode="w")
@@ -92,6 +90,9 @@ device_map='cuda')
 # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
 # vae = compile_unet(vae, config=config)
 
+#finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
+#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
+#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
 
 
 unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
@@ -99,7 +100,8 @@ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='
             device_map='cpu').to(dtype)
 
 adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
-pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
+pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype,
+                                           unet=unet, text_encoder=text_encoder)
 pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
 pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
 pipe.set_adapters(["lcm-lora"], [.9])
@@ -114,7 +116,7 @@ pipe.fuse_lora()
 
 pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
 # This IP adapter improves outputs substantially.
-pipe.set_ip_adapter_scale(.9)
+pipe.set_ip_adapter_scale(.6)
 pipe.unet.fuse_qkv_projections()
 #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
 
@@ -122,21 +124,71 @@ pipe.to(device=DEVICE)
 #pipe.unet = torch.compile(pipe.unet)
 #pipe.vae = torch.compile(pipe.vae)
 
-@spaces.GPU()
-def generate_gpu(in_im_embs):
-    in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
-    output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
-    im_emb, _ = pipe.encode_image(
-        output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
-    )
-    im_emb = im_emb.detach().to('cpu').to(torch.float32)
-    return output, im_emb
 
+#############################################################
+
+from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt-224', torch_dtype=dtype, quantization_config=quantization_config).eval()
+processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
 
-def generate(in_im_embs):
-    output, im_emb = generate_gpu(in_im_embs)
-    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
+
+
+def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
+    inputs_embeds = pali.get_input_embeddings()(input_ids)
+    selected_image_feature = image_outputs.to(dtype).to(device)
+    image_features = pali.multi_modal_projector(selected_image_feature)
+
+    if cache_position is None:
+        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+    inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
+        image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
+    )
+    return inputs_embeds
 
+
+
+def generate_pali(user_emb):
+    prompt = 'caption en'
+    model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
+    # we need to get im_embs taken in here.
+    input_len = model_inputs["input_ids"].shape[-1]
+    input_embeds = to_wanted_embs(user_emb.squeeze()[None, None, :].repeat(1, 256, 1),
+                                  model_inputs["input_ids"].to(device),
+                                  model_inputs["attention_mask"].to(device))
+
+    generation = pali.generate(max_new_tokens=100, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
+    decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+    return decoded
+
+
+
+
+#############################################################
+
+
+
+@spaces.GPU()
+def generate_gpu(in_im_embs, prompt='the scene'):
+    with torch.no_grad():
+        in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
+        output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
+        im_emb, _ = pipe.encode_image(
+            output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
+        )
+        im_emb = im_emb.detach().to('cpu').to(torch.float32)
+        im = torchvision.transforms.ToTensor()(output.frames[0][len(output.frames[0])//2]).unsqueeze(0)
+        im = torch.nn.functional.interpolate(im, (224, 224))
+        im = (im - .5) * 2
+        gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
+    return output, im_emb, gemb
+
+
+def generate(in_im_embs, prompt='the scene'):
+    output, im_emb, gemb = generate_gpu(in_im_embs, prompt)
+    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
+    print(prompt)
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
 
@@ -149,19 +201,19 @@ def generate(in_im_embs):
     output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
 
     write_video(path, output.frames[0])
-    return path, im_emb
+    return path, im_emb, gemb
 
 
 #######################
 
 def get_user_emb(embs, ys):
     # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
+
     if len(list(ys)) <= 7:
-        aways = [.01*torch.randn(1280) for i in range(3)]
+        aways = [.01*torch.randn_like(embs[0]) for i in range(3)]
        embs += aways
        awal = [0 for i in range(3)]
        ys += awal
-        print('Fixing only one feedback class available.\n')
 
     indices = list(range(len(embs)))
     # sample only as many negatives as there are positives
@@ -176,21 +228,20 @@ def get_user_emb(embs, ys):
     # this ends up adding a rating but losing an embedding, it seems.
     # let's take off a rating if so to continue without indexing errors.
     if len(ys) > len(embs):
+        print('ys are longer than embs; popping latest rating')
         ys.pop(-1)
 
     feature_embs = torch.stack([embs[i].squeeze().to('cpu') for i in indices]).to('cpu')
     #scaler = preprocessing.StandardScaler().fit(feature_embs)
     #feature_embs = scaler.transform(feature_embs)
-
+    chosen_y = np.array([ys[i] for i in indices])
 
     if feature_embs.norm() != 0:
         feature_embs = feature_embs / feature_embs.norm()
 
-    chosen_y = np.array([ys[i] for i in indices])
-
     #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
-    lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
-    coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
+    lin_class = SVC(max_iter=20, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs.squeeze(), chosen_y)
+    coef_ = torch.tensor(lin_class.coef_, dtype=torch.float32).detach().to('cpu')
     coef_ = coef_ / coef_.abs().max() * 3
 
     w = 1# if len(embs) % 2 == 0 else 0
@@ -212,7 +263,8 @@ def pluck_img(user_id, user_emb):
             best_sim = sim
             best_row = i[1]
     img = best_row['paths']
-    return img
+    text = best_row.get('text', '')
+    return img, text
 
 
 def background_next_image():
@@ -236,39 +288,48 @@ def background_next_image():
         unrated_from_user = not_rated_rows[[i[1]['from_user_id'] == uid for i in not_rated_rows.iterrows()]]
         rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]
 
-        # we pop previous ratings if there are > 10
-        if len(rated_from_user) >= 10:
+        # we pop previous ratings if there are > n
+        if len(rated_from_user) >= 15:
             oldest = rated_from_user.iloc[0]['paths']
             prevs_df = prevs_df[prevs_df['paths'] != oldest]
-        # we don't compute more after 10 are in the queue for them
+        # we don't compute more after n are in the queue for them
         if len(unrated_from_user) >= 10:
             continue
 
-        if len(rated_rows) < 4:
+        if len(rated_rows) < 5:
            continue
 
-        embs, ys = pluck_embs_ys(uid)
+        embs, ys, gembs = pluck_embs_ys(uid)
 
        user_emb = get_user_emb(embs, ys)
-        img, embs = generate(user_emb)
+
+        if len(gembs) > 4:
+            user_gem = get_user_emb(gembs, ys) / 4 # TODO scale this correctly; matplotlib, etc.
+            text = generate_pali(user_gem)
+        else:
+            text = generate_pali(torch.zeros(1, 1152))
+        img, embs, new_gem = generate(user_emb, text)
+
        if img:
-            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
+            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
            tmp_df['paths'] = [img]
            tmp_df['embeddings'] = [embs]
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
+            tmp_df['text'] = [text]
+            tmp_df['gemb'] = [new_gem]
            prevs_df = pd.concat((prevs_df, tmp_df))
-
            # we can free up storage by deleting the image
-            if len(prevs_df) > 30:
-                cands = prevs_df.iloc[6:]
-                cands['sum_bad_ratings'] = [sum([int(t==0) for t in i.values()]) for i in cands['user:rating']]
-                worst_row = cands.loc[cands['sum_bad_ratings']==cands['sum_bad_ratings'].max()].iloc[0]
-                worst_path = worst_row['paths']
-                if os.path.isfile(worst_path):
-                    os.remove(worst_path)
-                # only keep x images & embeddings & ips, then remove the most often disliked besides calibrating
-                prevs_df = prevs_df[prevs_df['paths'] != worst_path]
+            if len(prevs_df) > 500:
+                oldest_path = prevs_df.iloc[6]['paths']
+                if os.path.isfile(oldest_path):
+                    os.remove(oldest_path)
+                else:
+                    # If it fails, inform the user.
+                    print("Error: %s file not found" % oldest_path)
+                # only keep 50 images & embeddings & ips, then remove oldest besides calibrating
+                prevs_df = pd.concat((prevs_df.iloc[:6], prevs_df.iloc[7:]))
+
 
 def pluck_embs_ys(user_id):
     rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
@@ -281,21 +342,21 @@ def pluck_embs_ys(user_id):
 
     embs = rated_rows['embeddings'].to_list()
     ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
-    return embs, ys
+    gembs = rated_rows['gemb'].to_list()
+    return embs, ys, gembs
 
 def next_image(calibrate_prompts, user_id):
-
     with torch.no_grad():
         if len(calibrate_prompts) > 0:
             cal_video = calibrate_prompts.pop(0)
             image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
 
-            return image, calibrate_prompts
+            return image, calibrate_prompts, ''
         else:
-            embs, ys = pluck_embs_ys(user_id)
+            embs, ys, gembs = pluck_embs_ys(user_id)
            user_emb = get_user_emb(embs, ys)
-            image = pluck_img(user_id, user_emb)
-            return image, calibrate_prompts
+            image, text = pluck_img(user_id, user_emb)
+            return image, calibrate_prompts, text
 
 
 
@@ -307,7 +368,7 @@ def next_image(calibrate_prompts, user_id):
 
 def start(_, calibrate_prompts, user_id, request: gr.Request):
     user_id = int(str(time.time())[-7:].replace('.', ''))
-    image, calibrate_prompts = next_image(calibrate_prompts, user_id)
+    image, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
     return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True, visible=False),
@@ -326,14 +387,15 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
     if choice == 'Like (L)':
        choice = 1
     elif choice == 'Neither (Space)':
-        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
-        return img, calibrate_prompts
+        img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
+        return img, calibrate_prompts, text
     else:
        choice = 0
 
     # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
     # TODO skip allowing rating & just continue
     if img == None:
+        print('NSFW -- choice is disliked')
        choice = 0
 
     row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
@@ -341,8 +403,8 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
     if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
        prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
     prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
-    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
-    return img, calibrate_prompts
+    img, calibrate_prompts, text = next_image(calibrate_prompts, user_id)
+    return img, calibrate_prompts, text
 
 css = '''.gradio-container{max-width: 700px !important}
 #description{text-align: center}
@@ -426,6 +488,8 @@ Explore the latent space without text prompts based on your preferences. Learn m
            elem_id="video_output"
        )
        img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
+        with gr.Row():
+            text = gr.Textbox(interactive=False, visible=True, label='Text')
        with gr.Row(equal_height=True):
            b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
            b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
@@ -433,17 +497,17 @@ Explore the latent space without text prompts based on your preferences. Learn m
        b1.click(
            choose,
            [img, b1, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        b2.click(
            choose,
            [img, b2, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        b3.click(
            choose,
            [img, b3, calibrate_prompts, user_id],
-            [img, calibrate_prompts],
+            [img, calibrate_prompts, text],
        )
        with gr.Row():
            b4 = gr.Button(value='Start')
@@ -464,20 +528,28 @@ log = logging.getLogger('log_here')
 log.setLevel(logging.ERROR)
 
 scheduler = BackgroundScheduler()
-scheduler.add_job(func=background_next_image, trigger="interval", seconds=.3)
+scheduler.add_job(func=background_next_image, trigger="interval", seconds=.5)
 scheduler.start()
 
 #thread = threading.Thread(target=background_next_image,)
 #thread.start()
 
+# TODO shouldn't call this before gradio launch, yeah?
 @spaces.GPU()
 def encode_space(x):
     im_emb, _ = pipe.encode_image(
        image, DEVICE, 1, output_hidden_state
     )
-    return im_emb.detach().to('cpu').to(torch.float32)
+
+
+    im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
+    im = torch.nn.functional.interpolate(im, (224, 224))
+    im = (im - .5) * 2
+    gemb = pali.vision_tower(im.to(device).to(dtype)).last_hidden_state.detach().to('cpu').to(torch.float32).mean(1)
+
+    return im_emb.detach().to('cpu').to(torch.float32), gemb
 
-# prep our calibration prompts
+# prep our calibration videos
 for im in [
    './first.mp4',
    './second.mp4',
@@ -485,16 +557,23 @@ for im in [
    './fourth.mp4',
    './fifth.mp4',
    './sixth.mp4',
+   './seventh.mp4',
+   './eigth.mp4',
+   './ninth.mp4',
+   './tenth.mp4',
    ]:
-    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
+    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
    tmp_df['paths'] = [im]
    image = list(imageio.imiter(im))
    image = image[len(image)//2]
-    im_emb = encode_space(image)
+    im_emb, gemb = encode_space(image)
 
    tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
+    tmp_df['gemb'] = [gemb.detach().to('cpu')]
    tmp_df['user:rating'] = [{' ': ' '}]
    prevs_df = pd.concat((prevs_df, tmp_df))
 
 
-demo.launch(share=True)
+demo.launch(share=True, server_port=8443)
+
+
eigth.mp4 ADDED
Binary file (47.7 kB)
 
ninth.mp4 ADDED
Binary file (255 kB)
 
seventh.mp4 ADDED
Binary file (50 kB)
 
tenth.mp4 ADDED
Binary file (129 kB)