rynmurdock commited on
Commit
ba0dc8e
1 Parent(s): 5ec2d9d
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -236,6 +236,18 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
236
  return image, embs, img_embs, ys, calibrate_prompts
237
  else:
238
  print('######### Roaming #########')
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  im_s = get_coeff(embs, ys)
241
  rng_prompt = random.choice(prompt_list)
@@ -251,11 +263,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
251
  img_emb = w * learn_emb.to(dtype=torch.float16)
252
  image, img_emb = predict(prompt, im_emb=img_emb)
253
  img_embs.append(img_emb)
254
-
255
- if len(embs) > 100:
256
- embs.pop(0)
257
- img_embs.pop(0)
258
- ys.pop(0)
259
  return image, embs, img_embs, ys, calibrate_prompts
260
 
261
 
@@ -292,7 +300,6 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
292
  else:
293
  choice = 0
294
 
295
- print(img, 'img')
296
  if img is None:
297
  print('NSFW -- choice is disliked')
298
  choice = 0
 
236
  return image, embs, img_embs, ys, calibrate_prompts
237
  else:
238
  print('######### Roaming #########')
239
+
240
+ pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
241
+ neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
242
+
243
+ if len(neg_indices) > 40:
244
+ neg_indices = neg_indices[1:]
245
+ # popping first negative rating due to > 25
246
+
247
+ indices = pos_indices + neg_indices
248
+ embs = [embs[i] for i in indices]
249
+ img_embs = [img_embs[i] for i in indices]
250
+ ys = [ys[i] for i in indices]
251
 
252
  im_s = get_coeff(embs, ys)
253
  rng_prompt = random.choice(prompt_list)
 
263
  img_emb = w * learn_emb.to(dtype=torch.float16)
264
  image, img_emb = predict(prompt, im_emb=img_emb)
265
  img_embs.append(img_emb)
266
+
 
 
 
 
267
  return image, embs, img_embs, ys, calibrate_prompts
268
 
269
 
 
300
  else:
301
  choice = 0
302
 
 
303
  if img is None:
304
  print('NSFW -- choice is disliked')
305
  choice = 0