rynmurdock commited on
Commit
2990438
1 Parent(s): 5663ecc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -30,6 +30,14 @@ glob_idx = 0
30
  def next_image(embs, ys, calibrate_prompts):
31
  global glob_idx
32
  glob_idx = glob_idx + 1
 
 
 
 
 
 
 
 
33
  with torch.no_grad():
34
  if len(calibrate_prompts) > 0:
35
  print('######### Calibrating with sample prompts #########')
 
30
  def next_image(embs, ys, calibrate_prompts):
31
  global glob_idx
32
  glob_idx = glob_idx + 1
33
+
34
+ # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
35
+ if len(calibrate_prompts) == 0 and len(list(set(ys))) <= 1:
36
+ embs.append(torch.zeros(1, 1280))
37
+ embs.append(torch.zeros(1, 1280))
38
+ ys.append(0)
39
+ ys.append(1)
40
+
41
  with torch.no_grad():
42
  if len(calibrate_prompts) > 0:
43
  print('######### Calibrating with sample prompts #########')