beyoru committed
Commit b0a1757 · verified · Parent: da212c2

Update app.py

Files changed (1): app.py (+33 -11)
app.py CHANGED
@@ -2,10 +2,9 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import numpy as np
-import string
 from huggingface_hub import InferenceClient
 
-# Initialize Inference Client for the model (Ensure you have the correct model ID)
+# Initialize Inference Client for the model (ensure you have the correct model ID)
 client = InferenceClient("Qwen/Qwen2.5-3B-Instruct")
 
 # Load tokenizer and model for EOU detection
@@ -42,6 +41,7 @@ def respond(
     max_tokens,
     temperature,
     top_p,
+    eou_threshold: float = 0.9,  # Probability threshold to stop or transition the conversation
 ):
     messages = [{"role": "system", "content": system_message}]
 
@@ -53,8 +53,10 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    # Get the response from the Qwen model (e.g., for conversation generation)
     response = ""
+    interruption_detected = False
+
+    # Stream the model response while checking for EOU
     for message in client.chat_completion(
         messages,
         max_tokens=max_tokens,
@@ -64,14 +66,28 @@
     ):
         token = message.choices[0].delta.content
         response += token
-        yield response
 
-    # After generating the response, get the EOU probability
-    eou_probability = get_eou_probability(messages)  # Get EOU prediction
-    print(f"EOU Probability: {eou_probability}")
+        # Check the EOU probability after each response chunk
+        chat_ctx = [{"role": "user", "content": m} for m in history]
+        chat_ctx.append({"role": "assistant", "content": response})
+
+        eou_probability = get_eou_probability(chat_ctx)
+
+        # If the EOU probability is above the threshold, treat it as an interruption or turn end
+        if eou_probability > eou_threshold:
+            interruption_detected = True
+            break  # Stop the response generation if EOU is high
+
+        yield response  # Keep yielding the response as it is generated
+
+    if interruption_detected:
+        # EOU is high, so end the assistant response early and report it
+        yield f"\nAssistant detected an interruption or end of turn. EOU Probability: {eou_probability:.2f}"
 
-    # Include the EOU probability in the output
-    yield f"\nEOU Probability: {eou_probability:.2f}"
+    # Continue if no interruption
+    if not interruption_detected:
+        yield response
+
 
 # Gradio interface setup
 demo = gr.ChatInterface(
@@ -87,9 +103,15 @@ demo = gr.ChatInterface(
             step=0.05,
             label="Top-p (nucleus sampling)",
         ),
+        gr.Slider(
+            minimum=0.0,
+            maximum=1.0,
+            value=0.9,
+            step=0.01,
+            label="EOU Probability Threshold",
+        ),
     ],
 )
 
 # Launch Gradio with public link sharing
-demo.launch(share=True)
-
+demo.launch(share=True)
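A note for readers of this diff: `get_eou_probability` is defined in the unchanged "Load tokenizer and model for EOU detection" section of app.py, outside the hunks shown above. A minimal sketch of what such a helper commonly looks like, assuming the EOU detector is a causal LM scored on an end-of-utterance marker token; the model ID `livekit/turn-detector` and the `<|im_end|>` token are illustrative assumptions, not taken from this commit:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical EOU model ID -- this commit does not show which model app.py loads.
EOU_MODEL_ID = "livekit/turn-detector"

eou_tokenizer = AutoTokenizer.from_pretrained(EOU_MODEL_ID)
eou_model = AutoModelForCausalLM.from_pretrained(EOU_MODEL_ID)
eou_model.eval()

def get_eou_probability(chat_ctx):
    """Estimate the probability that the current turn has ended.

    chat_ctx is a list of {"role": ..., "content": ...} dicts, the same
    shape the diff builds before calling this function.
    """
    # Render the conversation with the model's chat template, without
    # starting a fresh assistant turn.
    text = eou_tokenizer.apply_chat_template(
        chat_ctx, add_generation_prompt=False, tokenize=False
    )
    inputs = eou_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        logits = eou_model(**inputs).logits  # (1, seq_len, vocab_size)

    # Probability mass on the assumed end-of-utterance marker at the
    # next-token position.
    eou_token_id = eou_tokenizer.convert_tokens_to_ids("<|im_end|>")
    probs = torch.softmax(logits[0, -1], dim=-1)
    return probs[eou_token_id].item()

On the wiring: `gr.ChatInterface` passes `additional_inputs` values into the chat function positionally after `message` and `history`, so the new "EOU Probability Threshold" slider arrives as the added `eou_threshold` parameter of `respond` on every call.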