Spaces:
Running
on
A10G
Running
on
A10G
rynmurdock
commited on
Commit
•
ba0dc8e
1
Parent(s):
5ec2d9d
mechanics
Browse files
app.py
CHANGED
@@ -236,6 +236,18 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
|
|
236 |
return image, embs, img_embs, ys, calibrate_prompts
|
237 |
else:
|
238 |
print('######### Roaming #########')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
|
240 |
im_s = get_coeff(embs, ys)
|
241 |
rng_prompt = random.choice(prompt_list)
|
@@ -251,11 +263,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
|
|
251 |
img_emb = w * learn_emb.to(dtype=torch.float16)
|
252 |
image, img_emb = predict(prompt, im_emb=img_emb)
|
253 |
img_embs.append(img_emb)
|
254 |
-
|
255 |
-
if len(embs) > 100:
|
256 |
-
embs.pop(0)
|
257 |
-
img_embs.pop(0)
|
258 |
-
ys.pop(0)
|
259 |
return image, embs, img_embs, ys, calibrate_prompts
|
260 |
|
261 |
|
@@ -292,7 +300,6 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
|
|
292 |
else:
|
293 |
choice = 0
|
294 |
|
295 |
-
print(img, 'img')
|
296 |
if img is None:
|
297 |
print('NSFW -- choice is disliked')
|
298 |
choice = 0
|
|
|
236 |
return image, embs, img_embs, ys, calibrate_prompts
|
237 |
else:
|
238 |
print('######### Roaming #########')
|
239 |
+
|
240 |
+
pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
|
241 |
+
neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
|
242 |
+
|
243 |
+
if len(neg_indices) > 40:
|
244 |
+
neg_indices = neg_indices[1:]
|
245 |
+
# popping first negative rating due to > 25
|
246 |
+
|
247 |
+
indices = pos_indices + neg_indices
|
248 |
+
embs = [embs[i] for i in indices]
|
249 |
+
img_embs = [img_embs[i] for i in indices]
|
250 |
+
ys = [ys[i] for i in indices]
|
251 |
|
252 |
im_s = get_coeff(embs, ys)
|
253 |
rng_prompt = random.choice(prompt_list)
|
|
|
263 |
img_emb = w * learn_emb.to(dtype=torch.float16)
|
264 |
image, img_emb = predict(prompt, im_emb=img_emb)
|
265 |
img_embs.append(img_emb)
|
266 |
+
|
|
|
|
|
|
|
|
|
267 |
return image, embs, img_embs, ys, calibrate_prompts
|
268 |
|
269 |
|
|
|
300 |
else:
|
301 |
choice = 0
|
302 |
|
|
|
303 |
if img is None:
|
304 |
print('NSFW -- choice is disliked')
|
305 |
choice = 0
|