jonathanjordan21 commited on
Commit
8cc5910
·
verified ·
1 Parent(s): 6a4790a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -400,7 +400,8 @@ def respond(
400
  message,
401
  history: list[tuple[str, str]],
402
  threshold,
403
- is_multiple
 
404
  ):
405
  global codes_emb
406
  global undetected
@@ -429,13 +430,18 @@ def respond(
429
  text_emb = model.encode(message)
430
  scores = cos_sim(codes_emb, text_emb).mean(axis=-1)#[:,0]
431
 
 
 
432
  if is_multiple:
433
  request_details = []
434
  request_numbers = []
435
  request_scores = []
436
  # request_undetected = False
437
- for i,score in enumerate(scores):
438
- if score > threshold:
 
 
 
439
  request_details.append(codes[i][6:])
440
  request_numbers.append(codes[i][:3])
441
  request_scores.append(str( round(score.tolist(), 3) ) )
@@ -614,7 +620,7 @@ with gr.Blocks() as demo:
614
  additional_inputs=[
615
  gr.Number(0.5, label="confidence threshold", show_label=True, minimum=0., maximum=1.0, step=0.1),
616
  gr.Checkbox(label="multiple", info="Allow multiple request code numbers"),
617
- gr.Number(3, label="maximum number of request codes", show_label=True, step=1, precision=0)
618
  ],
619
  # type="messages",
620
  chatbot=gr.Chatbot(height=800),
 
400
  message,
401
  history: list[tuple[str, str]],
402
  threshold,
403
+ is_multiple,
404
+ n_num,
405
  ):
406
  global codes_emb
407
  global undetected
 
430
  text_emb = model.encode(message)
431
  scores = cos_sim(codes_emb, text_emb).mean(axis=-1)#[:,0]
432
 
433
+ scores_argsort = scores.argsort(scores)[::-1]
434
+
435
  if is_multiple:
436
  request_details = []
437
  request_numbers = []
438
  request_scores = []
439
  # request_undetected = False
440
+ # for i,score in enumerate(scores):
441
+ for i in scores_argsort:
442
+ if len(request_scores) > n_num:
443
+ break
444
+ if scores[i] > threshold:
445
  request_details.append(codes[i][6:])
446
  request_numbers.append(codes[i][:3])
447
  request_scores.append(str( round(score.tolist(), 3) ) )
 
620
  additional_inputs=[
621
  gr.Number(0.5, label="confidence threshold", show_label=True, minimum=0., maximum=1.0, step=0.1),
622
  gr.Checkbox(label="multiple", info="Allow multiple request code numbers"),
623
+ gr.Number(3, label="maximum number of request codes", show_label=True, minimum=1, step=1, precision=0)
624
  ],
625
  # type="messages",
626
  chatbot=gr.Chatbot(height=800),