import json
import traceback
from typing import Generator, List

import gradio as gr
from openai import OpenAI

from crop_utils import get_image_crop
from prompts import (
    get_chat_system_prompt,
    get_live_event_system_prompt,
    get_live_event_user_prompt,
    get_street_interview_prompt,
    get_street_interview_system_prompt,
)
from transcript import TranscriptProcessor
from utils import css, get_transcript_for_url, head
from utils import openai_tools as tools
from utils import setup_openai_key

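# NOTE: OpenAI clients are constructed inside each request handler rather
# than at import time, so that setup_openai_key() (called in main()) has
# already placed OPENAI_API_KEY in the environment before the first
# OpenAI() call.
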
def get_initial_analysis(
    transcript_processor: TranscriptProcessor, cid, rsid, origin, ct, uid
) -> Generator[str, None, None]:
    """Perform initial analysis of the transcript using OpenAI."""
    try:
        transcript = transcript_processor.get_transcript()
        speaker_mapping = transcript_processor.speaker_mapping
        client = OpenAI()
        # Plain http for local development, https everywhere else.
        link_start = "http" if "localhost" in origin else "https"
        if ct == "si":  # street interview
            user_prompt = get_street_interview_prompt(transcript, uid, rsid, link_start)
            system_prompt = get_street_interview_system_prompt(cid, rsid, origin, ct)
        else:
            system_prompt = get_live_event_system_prompt(
                cid, rsid, origin, ct, speaker_mapping, transcript
            )
            user_prompt = get_live_event_user_prompt(uid, link_start)
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            stream=True,
            temperature=0.1,
        )
        collected_messages = []
        # Iterate through the stream, yielding the accumulated message so far.
        for chunk in completion:
            if chunk.choices[0].delta.content is not None:
                collected_messages.append(chunk.choices[0].delta.content)
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Error in initial analysis: {str(e)}")
        yield "An error occurred during initial analysis. Please check your API key and file path."

def chat(
    message: str,
    chat_history: List,
    transcript_processor: TranscriptProcessor,
    cid,
    rsid,
    origin,
    ct,
    uid,
):
    try:
        client = OpenAI()
        link_start = "http" if "localhost" in origin else "https"
        speaker_mapping = transcript_processor.speaker_mapping
        system_prompt = get_chat_system_prompt(
            cid=cid,
            rsid=rsid,
            origin=origin,
            ct=ct,
            speaker_mapping=speaker_mapping,
            transcript=transcript_processor.get_transcript(),
            link_start=link_start,
        )
        messages = [{"role": "system", "content": system_prompt}]
        for user_msg, assistant_msg in chat_history:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Add the current message.
        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=tools,
            stream=True,
            temperature=0.3,
        )

        collected_messages = []
        tool_calls_detected = False
        for chunk in completion:
            if chunk.choices[0].delta.tool_calls:
                tool_calls_detected = True
                # Re-issue the request without streaming to get the complete
                # tool call in one piece.
                response = client.chat.completions.create(
                    model="gpt-4o",
                    messages=messages,
                    tools=tools,
                )
                if response.choices[0].message.tool_calls:
                    tool_call = response.choices[0].message.tool_calls[0]
                    if tool_call.function.name == "get_image":
                        # Return the image crops directly in the chat.
                        image_data = get_image_crop(cid, rsid, uid)
                        messages.append(response.choices[0].message)
                        messages.append(
                            {
                                "role": "tool",
                                "content": "Here are the Image Crops",
                                "name": tool_call.function.name,
                                "tool_call_id": tool_call.id,
                            }
                        )
                        yield image_data
                        return
                    if tool_call.function.name == "correct_speaker_name_with_url":
                        # Function arguments arrive as a JSON string, so parse
                        # them with json.loads rather than eval.
                        args = json.loads(tool_call.function.arguments)
                        url = args.get("url", None)
                        messages.append(response.choices[0].message)
                        if url:
                            transcript_processor.correct_speaker_mapping_with_agenda(url)
                            corrected_speaker_mapping = (
                                transcript_processor.speaker_mapping
                            )
                            tool_content = json.dumps(
                                {
                                    "speaker_mapping": f"Corrected Speaker Mapping... {corrected_speaker_mapping}"
                                }
                            )
                        else:
                            tool_content = "No URL Provided"
                        messages.append(
                            {
                                "role": "tool",
                                "content": tool_content,
                                "name": tool_call.function.name,
                                "tool_call_id": tool_call.id,
                            }
                        )
                        # Get the final response after the tool call and stream it.
                        final_response = client.chat.completions.create(
                            model="gpt-4o",
                            messages=messages,
                            stream=True,
                        )
                        collected_chunk = ""
                        for final_chunk in final_response:
                            if final_chunk.choices[0].delta.content:
                                collected_chunk += final_chunk.choices[0].delta.content
                                yield collected_chunk
                        return
                    elif tool_call.function.name == "correct_call_type":
                        args = json.loads(tool_call.function.arguments)
                        call_type = args.get("call_type", None)
                        if call_type:
                            # Re-run the initial analysis with the corrected call type.
                            for content in get_initial_analysis(
                                transcript_processor,
                                cid,
                                rsid,
                                origin,
                                call_type,
                                uid,
                            ):
                                yield content
                            return
                break  # Exit the streaming loop once a tool call is handled.
            if not tool_calls_detected and chunk.choices[0].delta.content is not None:
                collected_messages.append(chunk.choices[0].delta.content)
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Unexpected error in chat: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        yield "Sorry, there was an error processing your request."

def create_chat_interface():
    """Create and configure the chat interface."""
    with gr.Blocks(
        fill_height=True,
        fill_width=True,
        css=css,
        head=head,
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
        ),
    ) as demo:
        chatbot = gr.Chatbot(
            elem_id="chatbot_box",
            layout="bubble",
            show_label=False,
            show_share_button=False,
            show_copy_all_button=False,
            show_copy_button=False,
            render=True,
        )
        msg = gr.Textbox(elem_id="chatbot_textbox", show_label=False)

        # Per-session state for the transcript processor and URL parameters.
        transcript_processor_state = gr.State()
        call_id_state = gr.State()
        colab_id_state = gr.State()
        origin_state = gr.State()
        ct_state = gr.State()
        turl_state = gr.State()
        uid_state = gr.State()

        iframe_html = "<iframe id='link-frame'></iframe>"
        gr.HTML(value=iframe_html)  # Add iframe to the UI

        def respond(
            message: str,
            chat_history: List,
            transcript_processor,
            cid,
            rsid,
            origin,
            ct,
            uid,
        ):
            if not transcript_processor:
                bot_message = "Transcript processor not initialized."
                chat_history.append((message, bot_message))
                # respond() is a generator, so yield rather than return a value.
                yield "", chat_history
                return
            chat_history.append((message, ""))
            for chunk in chat(
                message,
                chat_history[:-1],  # Exclude the current incomplete message
                transcript_processor,
                cid,
                rsid,
                origin,
                ct,
                uid,
            ):
                chat_history[-1] = (message, chunk)
                yield "", chat_history

        msg.submit(
            respond,
            [
                msg,
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            [msg, chatbot],
        )
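        # Because respond() is a generator, Gradio re-renders the chatbot on
        # every yield, so the assistant message appears to type itself out as
        # chat() streams cumulative text.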
        # Handle initial loading with streaming.
        def on_app_load(request: gr.Request):
            turls = None
            cid = request.query_params.get("cid", None)
            rsid = request.query_params.get("rsid", None)
            origin = request.query_params.get("origin", None)
            ct = request.query_params.get("ct", None)
            turl = request.query_params.get("turl", None)
            uid = request.query_params.get("uid", None)
            pnames = request.query_params.get("pnames", None)

            required_params = ["cid", "rsid", "origin", "ct", "turl", "uid"]
            missing_params = [
                param
                for param in required_params
                if request.query_params.get(param) is None
            ]
            if missing_params:
                error_message = (
                    f"Missing required parameters: {', '.join(missing_params)}"
                )
                chatbot_value = [(None, error_message)]
                return [chatbot_value, None, None, None, None, None, None, None]

            if ct == "rp":
                # turl is comma-separated and pnames is underscore-escaped.
                turls = turl.split(",")
                pnames = [pname.replace("_", " ") for pname in pnames.split(",")]

            try:
                if turls:
                    transcript_data = []
                    for turl in turls:
                        print("Getting Transcript for URL")
                        transcript_data.append(get_transcript_for_url(turl))
                    print("Now creating Processor")
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data,
                        call_type=ct,
                        person_names=pnames,
                    )
                else:
                    transcript_data = get_transcript_for_url(turl)
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data, call_type=ct
                    )
                # Initialize with an empty message; the analysis stream fills it in.
                chatbot_value = [(None, "")]
                return [
                    chatbot_value,
                    transcript_processor,
                    cid,
                    rsid,
                    origin,
                    ct,
                    turl,
                    uid,
                ]
            except Exception as e:
                print(e)
                error_message = f"Error processing call_id {cid}: {str(e)}"
                chatbot_value = [(None, error_message)]
                return [chatbot_value, None, None, None, None, None, None, None]
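        # Example launch URL (host and values are hypothetical) showing the
        # query parameters on_app_load expects:
        #   https://<space-host>/?cid=call-123&rsid=rs-456&origin=https://example.com
        #       &ct=si&turl=https://example.com/t.json&uid=user-1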
        def display_processing_message(chatbot_value):
            """Display the processing message while maintaining state."""
            new_chatbot_value = [
                (None, "Video is being processed. Please wait for the results...")
            ]
            return new_chatbot_value

        def stream_initial_analysis(
            chatbot_value, transcript_processor, cid, rsid, origin, ct, uid
        ):
            if not transcript_processor:
                # This is a generator, so yield the unchanged value and stop.
                yield chatbot_value
                return
            try:
                for chunk in get_initial_analysis(
                    transcript_processor, cid, rsid, origin, ct, uid
                ):
                    # Update the existing message instead of creating a new one.
                    chatbot_value[0] = (None, chunk)
                    yield chatbot_value
            except Exception as e:
                chatbot_value[0] = (None, f"Error during analysis: {str(e)}")
                yield chatbot_value
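        # Three-step load sequence: fetch transcripts and populate state, show
        # a placeholder message, then stream the initial analysis into the
        # same chatbot slot.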
        demo.load(
            on_app_load,
            inputs=None,
            outputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                turl_state,
                uid_state,
            ],
        ).then(
            display_processing_message,
            inputs=[chatbot],
            outputs=[chatbot],
        ).then(
            stream_initial_analysis,
            inputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            outputs=[chatbot],
        )
    return demo
def main():
    """Main function to run the application."""
    try:
        setup_openai_key()
        demo = create_chat_interface()
        demo.launch(share=True)
    except Exception as e:
        print(f"Error starting application: {str(e)}")
        raise


if __name__ == "__main__":
    main()
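# Run locally with `python app.py`; launch(share=True) additionally prints a
# temporary public *.gradio.live link when sharing is available.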