Update app.py
app.py CHANGED
@@ -78,7 +78,7 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         self.quit = asyncio.Event()
         self.session = None
         self.last_frame_time = 0
-        self.conversation_history = []
+        self.conversation_history = []  # Added conversation history
         self.latest_text = ""


@@ -126,21 +126,18 @@ class GeminiHandler(AsyncAudioVideoStreamHandler):
         ):
             self.audio_queue.put_nowait(audio_response)

-    async def receive(self, frame: tuple[int, np.ndarray], text_input: str) -> None:
+    async def receive(self, frame: tuple[int, np.ndarray], text_input: str) -> None:  # Added text_input here
         _, array = frame
         array = array.squeeze()
-        audio_message = encode_audio(array)
         if self.session:
             if text_input:  # Checks if text was inputted
-
-
-
-
-
-
-
-                self.conversation_history.append({"role": "user", "content": str(base64.b64encode(array.tobytes()).decode("UTF-8"))})  # Stores conversation in history
-
+                full_prompt = PROMPT_BASE + "\n\n" + "User: " + text_input
+                await self.session.send({"mime_type": "text", "data": full_prompt})
+                self.conversation_history.append({"role": "user", "content": text_input})  # Add text conversation
+            elif array.size:  # Checks if audio was received
+                full_prompt = PROMPT_BASE + "\n\n" + "User: " + str(base64.b64encode(array.tobytes()).decode("UTF-8"))
+                await self.session.send({"mime_type": "text", "data": full_prompt})
+                self.conversation_history.append({"role": "user", "content": str(base64.b64encode(array.tobytes()).decode("UTF-8"))})

     async def emit(self) -> AudioEmitType:
         if not self.args_set.is_set():
@@ -149,16 +146,14 @@
             asyncio.create_task(self.connect(self.latest_args[1]))
         array = await self.audio_queue.get()
         return (self.output_sample_rate, array)
-
+
     def set_text(self, text):
         self.latest_text = text
-
-    def get_text(self):
-        return self.latest_text
-
+
     def clear_text(self):
-
-
+        self.latest_text = ""
+        return ""
+

     def shutdown(self) -> None:
         self.quit.set()
@@ -205,18 +200,18 @@ with gr.Blocks(css=css) as demo:
             image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"])
             text_input = gr.Textbox(label="Text Message", placeholder="Type your message here")
             send_button = gr.Button("Send")
+
             handler = GeminiHandler()
-            send_button.click(handler.set_text,inputs=[text_input], outputs=[])
+            send_button.click(handler.set_text, inputs=[text_input], outputs=[])
             send_button.click(handler.clear_text, inputs=[], outputs=[text_input])
             webrtc.stream(
                 handler,
-                inputs=[webrtc, api_key, image_input,
+                inputs=[webrtc, api_key, image_input, text_input],
                 outputs=[webrtc],
                 time_limit=90,
                 concurrency_limit=2,
             )


-
 if __name__ == "__main__":
     demo.launch()
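
For reference, the change above wires a text path through the streaming handler: the Send button stores the textbox value via set_text, webrtc.stream passes it into receive as text_input, and receive wraps it in the base prompt, sends it to the Gemini session, and records it in conversation_history. Below is a minimal, self-contained sketch of that flow; FakeSession, TextOnlyHandler, and the PROMPT_BASE value are hypothetical stand-ins for illustration and are not part of this commit.

import asyncio

PROMPT_BASE = "You are a helpful assistant."  # hypothetical stand-in; app.py defines its own prompt

class FakeSession:
    """Hypothetical stand-in for the Gemini live session used in app.py."""
    async def send(self, message: dict) -> None:
        print("sent to model:", message["data"])

class TextOnlyHandler:
    """Trimmed-down sketch of the text branch added to GeminiHandler.receive."""
    def __init__(self):
        self.session = FakeSession()
        self.conversation_history = []
        self.latest_text = ""

    def set_text(self, text):
        # Bound to the Send button: stash the message for the next receive() call.
        self.latest_text = text

    def clear_text(self):
        # Also bound to the Send button: reset the stored message and clear the textbox.
        self.latest_text = ""
        return ""

    async def receive(self, frame, text_input):
        # Mirrors the new text branch: wrap the user text in the base prompt,
        # forward it to the session, and record it in the history.
        if self.session and text_input:
            full_prompt = PROMPT_BASE + "\n\n" + "User: " + text_input
            await self.session.send({"mime_type": "text", "data": full_prompt})
            self.conversation_history.append({"role": "user", "content": text_input})

async def main():
    handler = TextOnlyHandler()
    handler.set_text("What is in this image?")        # user clicks Send
    await handler.receive(None, handler.latest_text)  # stream loop delivers the text
    handler.clear_text()
    print(handler.conversation_history)

asyncio.run(main())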