Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import whisper | |
| import asyncio | |
| import httpx | |
| import tempfile | |
| import os | |
| import requests | |
| import time | |
| import threading | |
| from datetime import datetime, timedelta | |
# Shared requests session (not referenced elsewhere in the visible code).
session = requests.Session()
from interview_protocol import protocols as interview_protocols

# Whisper speech-to-text model, loaded once at import and used by transcribe_audio().
model = whisper.load_model("base")

# Backend service location and HTTP settings.
base_url = "https://llm4socialisolation-fd4082d0a518.herokuapp.com"
# base_url = "http://localhost:8080"
timeout = 60  # seconds, passed to every httpx client/request
concurrency_count = 10

# mapping between display names and internal chatbot_type values
display_to_value = {
    'Echo': 'enhanced',
    'Breeze': 'baseline',
}
# Reverse lookup, derived from the forward table so the two cannot drift apart.
value_to_display = {internal: shown for shown, internal in display_to_value.items()}
def get_method_index(chapter, method):
    """Return the global position of *method* across all interview chapters.

    The topic lists of every chapter in ``interview_protocols`` are concatenated
    in the mapping's iteration order, and *method* is looked up in that combined
    sequence. ``chapter`` is accepted but not used.

    Raises:
        ValueError: if *method* does not appear in any chapter (from ``list.index``).
    """
    flattened = [topic for topics in interview_protocols.values() for topic in topics]
    return flattened.index(method)
async def initialization(api_key, chapter_name, topic_name, username, prompts, chatbot_type):
    """Ask the backend to start a new interview session.

    Args:
        api_key: user's API key, forwarded to the backend.
        chapter_name: selected interview chapter.
        topic_name: selected topic within the chapter.
        username: identifies the user's session on the backend.
        prompts: extra prompt fields merged into the JSON body (later keys win
            on collision, as with any ``**`` merge).
        chatbot_type: internal chatbot type ('enhanced' / 'baseline').

    Returns:
        A human-readable status string. This coroutine never raises.
    """
    url = f"{base_url}/api/initialization"
    headers = {'Content-Type': 'application/json'}
    data = {
        'api_key': api_key,
        'chapter_name': chapter_name,
        'topic_name': topic_name,
        'username': username,
        'chatbot_type': chatbot_type,
        **prompts,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(url, json=data, headers=headers)
        # httpx reports timeouts as httpx.TimeoutException, which is not an
        # asyncio.TimeoutError. Catching only the latter meant this branch never ran.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            print("The request timed out")
            return "Request timed out during initialization."
        except Exception as e:
            return f"Error in initialization: {str(e)}"
    if response.status_code == 200:
        return "Initialization successful."
    return f"Initialization failed: {response.text}"
def fetch_default_prompts(chatbot_type):
    """Fetch the backend's default prompt texts for *chatbot_type*.

    Sends a GET to ``base_url`` with ``chatbot_type`` as a query parameter.

    Returns:
        The decoded JSON object (a dict) on HTTP 200. An empty dict is returned
        on a non-200 status, a transport or JSON error, or a payload that is not
        a JSON object. Callers use ``.get`` on the result, so a dict is always
        returned.
    """
    try:
        # Let httpx build and URL-encode the query string instead of formatting it by hand.
        response = httpx.get(base_url, params={'chatbot_type': chatbot_type}, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to fetch prompts: {response.status_code} - {response.text}")
            return {}
        prompts = response.json()
    except Exception as e:
        print(f"Error fetching prompts: {str(e)}")
        return {}
    print(prompts)
    if not isinstance(prompts, dict):
        print(f"Unexpected prompts payload: {prompts!r}")
        return {}
    return prompts
async def get_backend_response(api_key, patient_prompt, username, chatbot_type):
    """Send the user's message to the backend and return its reply.

    Returns:
        The decoded JSON body on HTTP 200. Otherwise an error *string*; the
        caller tells the two apart with ``isinstance(..., str)``.

    ``api_key`` is part of the signature but is not included in this request.
    """
    payload = {
        'username': username,
        'patient_prompt': patient_prompt,
        'chatbot_type': chatbot_type,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            reply = await client.post(
                f"{base_url}/responses/doctor",
                json=payload,
                headers={'Content-Type': 'application/json'},
            )
            if reply.status_code != 200:
                return f"Failed to fetch response from backend: {reply.text}"
            return reply.json()
        except Exception as exc:
            return f"Error contacting backend service: {str(exc)}"
async def save_conversation_and_memory(username, chatbot_type):
    """Ask the backend to end the current session and persist it.

    Returns:
        The response's ``message`` field on HTTP 200 (``'Saving Error!'`` if
        that field is absent). Otherwise an error string. Never raises.
    """
    endpoint = f"{base_url}/save/end_and_save"
    body = {'username': username, 'chatbot_type': chatbot_type}
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            resp = await client.post(endpoint, json=body, headers={'Content-Type': 'application/json'})
            if resp.status_code == 200:
                return resp.json().get('message', 'Saving Error!')
            return f"Failed to save conversations and memory graph: {resp.text}"
        except Exception as err:
            return f"Error contacting backend service: {str(err)}"
async def get_conversation_histories(username, chatbot_type):
    """Fetch the saved conversations for *username* from the backend.

    Returns:
        The decoded JSON list on HTTP 200. On a non-200 status, a transport or
        JSON error, or a payload that is not a list, the problem is printed and
        an empty list is returned. The download handler then produces no files
        instead of failing.
    """
    url = f"{base_url}/save/download_conversations"
    headers = {'Content-Type': 'application/json'}
    data = {'username': username, 'chatbot_type': chatbot_type}
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(url, json=data, headers=headers)
            if response.status_code != 200:
                print(f"Failed to fetch conversations: {response.status_code} - {response.text}")
                return []
            conversation_data = response.json()
        except Exception as e:
            # Still best-effort, but no longer silent.
            print(f"Error fetching conversations: {str(e)}")
            return []
    if not isinstance(conversation_data, list):
        print(f"Unexpected conversations payload: {conversation_data!r}")
        return []
    return conversation_data
def download_conversations(username, chatbot_type):
    """Write each saved conversation to its own text file for download.

    Each entry from the backend is expected to be a dict with an optional
    ``file_name`` and a ``conversation`` list of ``[speaker, message]`` pairs.
    Pairs that do not match that shape are written as ``Unknown format: ...``.
    Entries that are not dicts are skipped.

    Returns:
        A list of paths to files inside a fresh temporary directory. The list is
        possibly empty.
    """
    conversation_histories = asyncio.run(get_conversation_histories(username, chatbot_type))
    files = []
    temp_dir = tempfile.mkdtemp()
    for conversation_entry in conversation_histories:
        if not isinstance(conversation_entry, dict):
            continue
        default_name = f"Conversation_{len(files)+1}.txt"
        # The name comes from the backend. Keep only its final path component so
        # it cannot escape temp_dir (e.g. "../x" or an absolute path), and fall
        # back to the default name when it is missing, None or empty.
        file_name = os.path.basename(str(conversation_entry.get('file_name') or '')) or default_name
        temp_file_path = os.path.join(temp_dir, file_name)
        # Never overwrite an earlier conversation that happens to share a name.
        stem, ext = os.path.splitext(file_name)
        duplicate = 1
        while os.path.exists(temp_file_path):
            duplicate += 1
            temp_file_path = os.path.join(temp_dir, f"{stem}_{duplicate}{ext}")

        parts = []
        for message_pair in conversation_entry.get('conversation') or []:
            if isinstance(message_pair, list) and len(message_pair) == 2:
                speaker, message = message_pair
                parts.append(f"{str(speaker).capitalize()}: {message}\n\n")
            else:
                parts.append(f"Unknown format: {message_pair}\n\n")
        # Explicit UTF-8 so non-ASCII text (accents, emoji) does not depend on
        # the platform's default encoding.
        with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
            temp_file.write("".join(parts))
        files.append(temp_file_path)
    return files
async def get_biography(username, chatbot_type):
    """Ask the backend to generate an autobiography for *username*.

    Returns:
        The ``biography`` field of the JSON reply on HTTP 200 ('' if absent).
        On another status, the fixed string ``"Failed to generate biography."``.
        On any exception, a string starting with
        ``"Error contacting backend service:"``.
    """
    request_body = {'username': username, 'chatbot_type': chatbot_type}
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            result = await client.post(
                f"{base_url}/save/generate_autobiography",
                json=request_body,
                headers={'Content-Type': 'application/json'},
            )
            if result.status_code != 200:
                return "Failed to generate biography."
            return result.json().get('biography', '')
        except Exception as exc:
            return f"Error contacting backend service: {str(exc)}"
def download_biography(username, chatbot_type):
    """Generate the biography, save it to a text file, and show it in the UI.

    Returns:
        A pair ``(file_output, textbox_update)``. On success, the first item is
        the path of ``biography.txt`` in a fresh temp directory. On failure, it
        is an update that hides the file component. In both cases the textbox
        is made visible with the returned text.
    """
    biography_text = asyncio.run(get_biography(username, chatbot_type))
    # get_biography reports failures with these exact message prefixes. Match
    # them precisely: the old substring test for "Failed"/"Error" discarded
    # genuine biographies that merely contained those words.
    error_prefixes = ("Failed to generate biography.", "Error contacting backend service:")
    if (not biography_text
            or not isinstance(biography_text, str)
            or biography_text.startswith(error_prefixes)):
        return gr.update(value=None, visible=False), gr.update(value=biography_text, visible=True)
    temp_dir = tempfile.mkdtemp()
    temp_file_path = os.path.join(temp_dir, "biography.txt")
    # Explicit UTF-8 so non-ASCII text does not depend on the platform default encoding.
    with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
        temp_file.write(biography_text)
    return temp_file_path, gr.update(value=biography_text, visible=True)
def transcribe_audio(audio_file):
    """Run the module-level Whisper ``model`` on *audio_file* and return its ``"text"`` result."""
    result = model.transcribe(audio_file)
    return result["text"]
def submit_text_and_respond(edited_text, api_key, username, history, chatbot_type):
    """Send the user's (possibly edited) text to the backend and record the reply.

    Appends ``(edited_text, reply)`` to *history* in place. The submit event
    wires the history ``gr.State`` as an input but not as an output, so the
    in-place append appears to be how the history persists between turns.
    NOTE(review): confirm that Gradio passes the stored list itself.

    Returns:
        ``(history, "", memory_rows)``. The empty string clears the transcription
        box, and ``memory_rows`` feeds the "Memory Events" table (``[]`` on any
        error).
    """
    response = asyncio.run(get_backend_response(api_key, edited_text, username, chatbot_type))
    print('------')
    print(response)
    if isinstance(response, str):
        # get_backend_response returns a str only for errors; show it in the chat.
        history.append((edited_text, response))
        return history, "", []
    # Guard against a malformed payload so the handler reports the problem
    # instead of raising KeyError/TypeError inside Gradio.
    doctor = response.get('doctor_response') if isinstance(response, dict) else None
    if not isinstance(doctor, dict) or 'response' not in doctor:
        history.append((edited_text, f"Unexpected response from backend: {response}"))
        return history, "", []
    history.append((edited_text, doctor['response']))
    # `or []` also covers an explicit null for memory_events.
    memory_graph = update_memory_graph(response.get('memory_events') or [])
    return history, "", memory_graph
def set_initialize_button(api_key_input, chapter_name, topic_name, username_input,
                          system_prompt_text, conv_instruction_prompt_text, therapy_prompt_text,
                          autobio_prompt_text, chatbot_display_name):
    """Handle the Initialize button.

    Maps the dropdown display name to the internal chatbot type (default
    'enhanced'), bundles the four prompt texts, and runs ``initialization``
    synchronously.

    Returns:
        ``(status message, api key, internal chatbot type)`` for the status box
        and the API-key / chatbot-type states.
    """
    chatbot_type = display_to_value.get(chatbot_display_name, 'enhanced')
    # NOTE(review): the default prompts are fetched under the key
    # 'autobio_generation_prompt' (update_prompts), but the edited text is sent
    # here as 'autobio_prompt'. Confirm which key the backend expects.
    prompts = dict(
        system_prompt=system_prompt_text,
        conv_instruction_prompt=conv_instruction_prompt_text,
        therapy_prompt=therapy_prompt_text,
        autobio_prompt=autobio_prompt_text,
    )
    status = asyncio.run(
        initialization(api_key_input, chapter_name, topic_name, username_input, prompts, chatbot_type)
    )
    print(status)
    return status, api_key_input, chatbot_type
def save_conversation(username, chatbot_type):
    """Synchronously end and save the session; return the backend's status text."""
    return asyncio.run(save_conversation_and_memory(username, chatbot_type))
def start_recording(audio_file):
    """Transcribe a recorded clip for the editable transcription box.

    Returns '' when no audio is given, the transcription on success, or a
    ``"Failed to transcribe: ..."`` message if transcription raises.
    """
    if audio_file:
        try:
            return transcribe_audio(audio_file)
        except Exception as exc:
            return f"Failed to transcribe: {str(exc)}"
    return ""
def update_methods(chapter):
    """Refresh the topic dropdown after the chapter selection changes.

    Returns a ``gr.update`` whose choices are the chapter's topics and whose
    value is the first topic. An unknown or cleared (None) chapter, or a chapter
    with no topics, yields empty choices and no selection instead of raising
    KeyError/IndexError.
    """
    topics = interview_protocols.get(chapter) or []
    return gr.update(choices=topics, value=topics[0] if topics else None)
def update_memory_graph(memory_data):
    """Convert memory-event dicts into rows for the "Memory Events" table.

    Args:
        memory_data: iterable of dicts (or None). Each dict may contain
            ``date``, ``topic``, ``event_description`` and ``people_involved``.

    Returns:
        A list of ``[date, topic, event_description, people_involved]`` rows,
        with '' for missing keys. ``None`` or empty input gives ``[]``. A
        list/tuple ``people_involved`` is joined with ", " because the table
        declares that column as a string.
    """
    if not memory_data:
        return []
    rows = []
    for node in memory_data:
        people = node.get('people_involved', '')
        if isinstance(people, (list, tuple)):
            people = ", ".join(str(person) for person in people)
        rows.append([
            node.get('date', ''),
            node.get('topic', ''),
            node.get('event_description', ''),
            people,
        ])
    return rows
def update_prompts(chatbot_display_name):
    """Load the default prompts for the selected chatbot into the four prompt boxes.

    Returns a 4-tuple of ``gr.update(value=...)`` for the system,
    conversation-instruction, therapy and autobiography-generation prompts, in
    that order. Missing keys yield ''.
    """
    defaults = fetch_default_prompts(display_to_value.get(chatbot_display_name, 'enhanced'))
    prompt_keys = (
        'system_prompt',
        'conv_instruction_prompt',
        'therapy_prompt',
        'autobio_generation_prompt',
    )
    return tuple(gr.update(value=defaults.get(key, '')) for key in prompt_keys)
def update_chatbot_type(chatbot_display_name):
    """Translate a dropdown display name to its internal chatbot_type.

    Unknown names fall back to 'enhanced'.
    """
    return display_to_value.get(chatbot_display_name, 'enhanced')
# Start the countdown: mark the timer as running and fix its deadline.
def start_timer():
    """Return ``(True, deadline)`` where *deadline* is eight minutes from now.

    The two values feed the ``is_running`` and ``target_timestamp`` states.
    """
    deadline = datetime.now() + timedelta(minutes=8)
    return True, deadline
def reset_timer():
    """Return ``(False, idle label)`` to stop the countdown display.

    The label now reads "Time remaining: 08:00", matching the initial timer
    textbox and the MM:SS format of periodic_call. The previous
    "Timer remaining: 8:00" had a typo and an unpadded minute field.
    """
    is_running = False
    return is_running, "Time remaining: 08:00"
# Compute the countdown label; the UI polls this (synchronous) function once per second.
def periodic_call(is_running, target_timestamp):
    """Return the "Time remaining: MM:SS" label for the countdown timer.

    Args:
        is_running: truthy once the timer has been started.
        target_timestamp: deadline ``datetime``. Only read while running.

    Returns:
        The remaining time, rounded to whole seconds and clamped at 00:00. When
        the timer is not running, returns the idle label "Time remaining: 08:00".
        That label is now zero-padded like the running format and the initial
        textbox (it was "8:00").
    """
    if not is_running:
        return 'Time remaining: 08:00'
    remaining = int(round((target_timestamp - datetime.now()).total_seconds()))
    minutes, seconds = divmod(max(remaining, 0), 60)
    return f'Time remaining: {minutes:02}:{seconds:02}'
# Empty placeholder for every prompt field (not referenced elsewhere in the visible code).
initial_prompts = dict.fromkeys(
    ('system_prompt', 'conv_instruction_prompt', 'therapy_prompt', 'autobio_generation_prompt'),
    '',
)

# Stylesheet that keeps the elements with ids start_button / reset_button compact.
css = """
#start_button, #reset_button {
padding: 4px 10px !important;
font-size: 12px !important;
width: auto !important;
}
"""
# Build the Gradio UI: a settings column on the left, the timer/chat column on the right.
with gr.Blocks(css=css) as app:
    # Per-session state shared between event handlers.
    chatbot_type_state = gr.State('enhanced')   # internal chatbot type, see display_to_value
    api_key_state = gr.State()                  # API key captured by the Initialize handler
    prompt_visibility_state = gr.State(False)   # whether the prompt editor is shown
    is_running = gr.State()                     # timer flag, set by start_timer
    target_timestamp = gr.State()               # timer deadline, set by start_timer
    with gr.Row():
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Settings")
            # chatbot Type Selection
            with gr.Box():
                gr.Markdown("### Chatbot Selection")
                chatbot_type_dropdown = gr.Dropdown(
                    label="Select Chatbot Type",
                    choices=['Echo', 'Breeze'],
                    value='Echo',
                )
                chatbot_type_dropdown.change(
                    fn=update_chatbot_type,
                    inputs=[chatbot_type_dropdown],
                    outputs=[chatbot_type_state]
                )
            # Fetch the initial prompts for the default chatbot type once, while
            # the UI is being built. Each value is the update dict from update_prompts.
            system_prompt_value, conv_instruction_prompt_value, therapy_prompt_value, autobio_prompt_value = update_prompts('Echo')
            # interview protocol selection
            with gr.Box():
                gr.Markdown("### Interview Protocol")
                # Default to the 2nd chapter and its 4th topic as before, but clamp
                # the indices so a shorter protocol list cannot raise IndexError
                # while the app is being built.
                _chapter_names = list(interview_protocols.keys())
                _default_chapter = _chapter_names[min(1, len(_chapter_names) - 1)]
                _default_topics = interview_protocols[_default_chapter]
                _default_topic = _default_topics[min(3, len(_default_topics) - 1)] if _default_topics else None
                chapter_dropdown = gr.Dropdown(
                    label="Select Chapter",
                    choices=_chapter_names,
                    value=_default_chapter,
                )
                method_dropdown = gr.Dropdown(
                    label="Select Topic",
                    choices=_default_topics,
                    value=_default_topic,
                )
                chapter_dropdown.change(
                    fn=update_methods,
                    inputs=[chapter_dropdown],
                    outputs=[method_dropdown]
                )

                # Update states when selections change.
                # NOTE(review): chapter_state/method_state are written here but no
                # handler in this file reads them.
                def update_chapter(chapter):
                    return chapter

                def update_method(method):
                    return method

                chapter_state = gr.State()
                method_state = gr.State()
                chapter_dropdown.change(
                    fn=update_chapter,
                    inputs=[chapter_dropdown],
                    outputs=[chapter_state]
                )
                method_dropdown.change(
                    fn=update_method,
                    inputs=[method_dropdown],
                    outputs=[method_state]
                )
            # customize Prompts
            with gr.Box():
                toggle_prompts_button = gr.Button("Show Prompts")
                # wrap the prompts in a component with initial visibility set to False
                with gr.Column(visible=False) as prompt_section:
                    gr.Markdown("### Customize Prompts")
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        placeholder="Enter the system prompt here.",
                        value=system_prompt_value['value']
                    )
                    conv_instruction_prompt = gr.Textbox(
                        label="Conversation Instruction Prompt",
                        placeholder="Enter the instruction for each conversation here.",
                        value=conv_instruction_prompt_value['value']
                    )
                    therapy_prompt = gr.Textbox(
                        label="Therapy Prompt",
                        placeholder="Enter the instruction for reminiscence therapy.",
                        value=therapy_prompt_value['value']
                    )
                    autobio_prompt = gr.Textbox(
                        label="Autobiography Generation Prompt",
                        placeholder="Enter the instruction for autobiography generation.",
                        value=autobio_prompt_value['value']
                    )
                # update prompts when chatbot_type changes
                chatbot_type_dropdown.change(
                    fn=update_prompts,
                    inputs=[chatbot_type_dropdown],
                    outputs=[system_prompt, conv_instruction_prompt, therapy_prompt, autobio_prompt]
                )
            with gr.Box():
                gr.Markdown("### User Information")
                username_input = gr.Textbox(
                    label="Username", placeholder="Enter your username"
                )
                api_key_input = gr.Textbox(
                    label="OpenAI API Key",
                    placeholder="Enter your openai api key",
                    type="password"
                )
                initialize_button = gr.Button("Initialize", variant="primary", size="large")
                initialization_status = gr.Textbox(
                    label="Status", interactive=False, placeholder="Initialization status will appear here."
                )
                initialize_button.click(
                    fn=set_initialize_button,
                    inputs=[api_key_input, chapter_dropdown, method_dropdown, username_input,
                            system_prompt, conv_instruction_prompt, therapy_prompt, autobio_prompt, chatbot_type_dropdown],
                    outputs=[initialization_status, api_key_state, chatbot_type_state],
                )

            # Flip prompt-editor visibility and relabel the toggle button.
            def toggle_prompts(visibility):
                new_visibility = not visibility
                button_text = "Hide Prompts" if new_visibility else "Show Prompts"
                return gr.update(value=button_text), gr.update(visible=new_visibility), new_visibility

            toggle_prompts_button.click(
                fn=toggle_prompts,
                inputs=[prompt_visibility_state],
                outputs=[toggle_prompts_button, prompt_section, prompt_visibility_state]
            )
        with gr.Column(scale=3):
            with gr.Row():
                timer_display = gr.Textbox(value="Time remaining: 08:00", label="")
                start_button = gr.Button("Start Timer", elem_id="start_button")
                # Set the deadline, then poll periodic_call every second to refresh
                # the label. NOTE(review): each click chains another every=1 job;
                # confirm that repeated clicks do not stack several timers.
                start_button.click(start_timer, outputs=[is_running, target_timestamp]).then(
                    periodic_call, inputs=[is_running, target_timestamp], outputs=timer_display, every=1)
            chatbot = gr.Chatbot(label="Chat here for autobiography generation", height=500)
            with gr.Row():
                transcription_box = gr.Textbox(
                    label="Transcription (You can edit this)", lines=3
                )
                audio_input = gr.Audio(
                    source="microphone", type="filepath", label="🎤 Record Audio"
                )
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", size="large")
                save_conversation_button = gr.Button("End and Save Conversation", variant="secondary")
                download_button = gr.Button("Download Conversations", variant="secondary")
                download_biography_button = gr.Button("Download Biography", variant="secondary")
            memory_graph_table = gr.Dataframe(
                headers=["Date", "Topic", "Description", "People Involved"],
                datatype=["str", "str", "str", "str"],
                interactive=False,
                label="Memory Events",
                max_rows=5
            )
            biography_textbox = gr.Textbox(label="Autobiography", visible=False)
            # Transcribe each new recording into the editable text box.
            audio_input.change(
                fn=start_recording,
                inputs=[audio_input],
                outputs=[transcription_box]
            )
            # Chat history list. It is an input but not an output of the submit
            # event; submit_text_and_respond appends to it in place.
            state = gr.State([])
            submit_button.click(
                submit_text_and_respond,
                inputs=[transcription_box, api_key_state, username_input, state, chatbot_type_state],
                outputs=[chatbot, transcription_box, memory_graph_table]
            )
            # The file output components below are created inline; presumably
            # Gradio renders them at this point in the layout.
            download_button.click(
                fn=download_conversations,
                inputs=[username_input, chatbot_type_state],
                outputs=gr.Files()
            )
            download_biography_button.click(
                fn=download_biography,
                inputs=[username_input, chatbot_type_state],
                outputs=[gr.File(label="Biography.txt"), biography_textbox]
            )
            # NOTE(review): outputs=None discards the status string returned by
            # save_conversation, so the user gets no feedback on success or failure.
            save_conversation_button.click(
                fn=save_conversation,
                inputs=[username_input, chatbot_type_state],
                outputs=None
            )
# Enable the request queue (required for the every=1 timer) and let up to
# `concurrency_count` events run at once. Without this argument the Gradio 3.x
# queue processes one event at a time, so a slow backend call or Whisper
# transcription would stall every other user and the timer ticks.
app.queue(concurrency_count=concurrency_count)
app.launch(share=True, max_threads=10)