import json
from typing import Generator, List

import gradio as gr
from openai import OpenAI

from crop_utils import get_image_crop
from prompts import (
    get_chat_system_prompt,
    get_live_event_system_prompt,
    get_live_event_user_prompt,
    get_street_interview_prompt,
    get_street_interview_system_prompt,
)
from transcript import TranscriptProcessor
from utils import css, get_transcript_for_url, head
from utils import openai_tools as tools
from utils import setup_openai_key

# NOTE(review): constructed at import time, *before* main() calls
# setup_openai_key(); every function below creates its own client, so this
# module-level instance appears unused — confirm before removing.
client = OpenAI()


def get_initial_analysis(
    transcript_processor: TranscriptProcessor, cid, rsid, origin, ct, uid
) -> Generator[str, None, None]:
    """Stream an initial OpenAI analysis of the transcript.

    Chooses the street-interview prompts when ``ct == "si"`` and the
    live-event prompts otherwise, then yields the accumulated response text
    after each streamed chunk so the UI can render it progressively.

    Args:
        transcript_processor: Provides the transcript text and speaker mapping.
        cid: Call id used in the prompts.
        rsid: Session id used in the prompts.
        origin: Requesting origin; decides http vs https link prefixes.
        ct: Call type ("si" selects the street-interview prompts).
        uid: User id forwarded to the prompt builders.

    Yields:
        The response text accumulated so far (one yield per content chunk),
        or a fixed error string if anything raises.
    """
    try:
        transcript = transcript_processor.get_transcript()
        speaker_mapping = transcript_processor.speaker_mapping
        client = OpenAI()
        # Local origins are served over plain HTTP; everything else HTTPS.
        if "localhost" in origin:
            link_start = "http"
        else:
            link_start = "https"

        if ct == "si":  # street interview
            user_prompt = get_street_interview_prompt(transcript, uid, rsid, link_start)
            system_prompt = get_street_interview_system_prompt(cid, rsid, origin, ct)
        else:
            system_prompt = get_live_event_system_prompt(
                cid, rsid, origin, ct, speaker_mapping, transcript
            )
            user_prompt = get_live_event_user_prompt(uid, link_start)

        # Both branches issued an identical request; the API call is hoisted
        # out of the if/else (behavior unchanged).
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            stream=True,
            temperature=0.1,
        )

        collected_messages = []
        # Accumulate the stream and yield the running concatenation.
        for chunk in completion:
            if chunk.choices[0].delta.content is not None:
                collected_messages.append(chunk.choices[0].delta.content)
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Error in initial analysis: {str(e)}")
        yield "An error occurred during initial analysis. Please check your API key and file path."


def chat(
    message: str,
    chat_history: List,
    transcript_processor: TranscriptProcessor,
    cid,
    rsid,
    origin,
    ct,
    uid,
):
    """Stream a chat response, dispatching OpenAI tool calls when requested.

    Builds the message list from ``chat_history`` plus the current
    ``message``, streams a completion, and — if the model starts a tool
    call — re-issues the request non-streaming to obtain the full tool-call
    payload, executes the tool, and yields its result.

    Yields:
        Accumulated assistant text (one yield per chunk), an image payload
        for the ``get_image`` tool, or a fixed error string on failure.
    """
    try:
        client = OpenAI()
        # Local origins are served over plain HTTP; everything else HTTPS.
        if "localhost" in origin:
            link_start = "http"
        else:
            link_start = "https"
        speaker_mapping = transcript_processor.speaker_mapping
        system_prompt = get_chat_system_prompt(
            cid=cid,
            rsid=rsid,
            origin=origin,
            ct=ct,
            speaker_mapping=speaker_mapping,
            transcript=transcript_processor.get_transcript(),
            link_start=link_start,
        )

        messages = [{"role": "system", "content": system_prompt}]
        for user_msg, assistant_msg in chat_history:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Add the current message
        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=tools,
            stream=True,
            temperature=0.3,
        )

        collected_messages = []
        tool_calls_detected = False
        for chunk in completion:
            if chunk.choices[0].delta.tool_calls:
                tool_calls_detected = True
                # Re-issue the request without streaming so the complete
                # tool-call payload (name + JSON arguments) arrives at once.
                response = client.chat.completions.create(
                    model="gpt-4o",
                    messages=messages,
                    tools=tools,
                )
                if response.choices[0].message.tool_calls:
                    tool_call = response.choices[0].message.tool_calls[0]

                    if tool_call.function.name == "get_image":
                        # Return the image directly in the chat
                        image_data = get_image_crop(cid, rsid, uid)
                        messages.append(response.choices[0].message)
                        function_call_result_message = {
                            "role": "tool",
                            "content": "Here are the Image Crops",
                            "name": tool_call.function.name,
                            "tool_call_id": tool_call.id,
                        }
                        messages.append(function_call_result_message)
                        yield image_data
                        return

                    if tool_call.function.name == "correct_speaker_name_with_url":
                        # FIX: tool arguments are a JSON string; json.loads
                        # replaces eval(), which executed model-controlled text.
                        args = json.loads(tool_call.function.arguments)
                        url = args.get("url", None)
                        if url:
                            transcript_processor.correct_speaker_mapping_with_agenda(
                                url
                            )
                            corrected_speaker_mapping = (
                                transcript_processor.speaker_mapping
                            )
                            messages.append(response.choices[0].message)
                            function_call_result_message = {
                                "role": "tool",
                                "content": json.dumps(
                                    {
                                        "speaker_mapping": f"Corrected Speaker Mapping... {corrected_speaker_mapping}"
                                    }
                                ),
                                "name": tool_call.function.name,
                                "tool_call_id": tool_call.id,
                            }
                            messages.append(function_call_result_message)
                            # Get final response after tool call
                            final_response = client.chat.completions.create(
                                model="gpt-4o",
                                messages=messages,
                                stream=True,
                            )
                            collected_chunk = ""
                            for final_chunk in final_response:
                                if final_chunk.choices[0].delta.content:
                                    collected_chunk += final_chunk.choices[
                                        0
                                    ].delta.content
                                    yield collected_chunk
                            return
                        # No URL provided: nothing to correct — fall through to
                        # break. (The original built a "No URL Provided" tool
                        # message here but never sent it; dead code removed.)

                    elif tool_call.function.name == "correct_call_type":
                        args = json.loads(tool_call.function.arguments)
                        call_type = args.get("call_type", None)
                        if call_type:
                            # Stream the analysis for the corrected call type.
                            # FIX: the original also passed call_type as the
                            # cid argument; cid must remain the call id.
                            for content in get_initial_analysis(
                                transcript_processor,
                                cid,
                                rsid,
                                origin,
                                call_type,
                                uid,
                            ):
                                yield content
                            return
                break  # Exit streaming loop once a tool call has been handled

            if not tool_calls_detected and chunk.choices[0].delta.content is not None:
                collected_messages.append(chunk.choices[0].delta.content)
                yield "".join(collected_messages)

    except Exception as e:
        print(f"Unexpected error in chat: {str(e)}")
        import traceback

        print(f"Traceback: {traceback.format_exc()}")
        yield "Sorry, there was an error processing your request."
def create_chat_interface():
    """Create and configure the Gradio chat interface.

    Wires up the chatbot widget, the textbox submit handler, and the
    app-load handler that parses query parameters and builds the
    TranscriptProcessor. Returns the configured gr.Blocks app.
    """
    with gr.Blocks(
        fill_height=True,
        fill_width=True,
        css=css,
        head=head,
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
        ),
    ) as demo:
        chatbot = gr.Chatbot(
            elem_id="chatbot_box",
            layout="bubble",
            show_label=False,
            show_share_button=False,
            show_copy_all_button=False,
            show_copy_button=False,
            render=True,
        )
        msg = gr.Textbox(elem_id="chatbot_textbox", show_label=False)

        # State holders for values parsed from the request query string.
        transcript_processor_state = gr.State()  # maintain state of imp things
        call_id_state = gr.State()
        colab_id_state = gr.State()
        origin_state = gr.State()
        ct_state = gr.State()
        turl_state = gr.State()
        uid_state = gr.State()
        iframe_html = ""
        gr.HTML(value=iframe_html)  # Add iframe to the UI

        def respond(
            message: str,
            chat_history: List,
            transcript_processor,
            cid,
            rsid,
            origin,
            ct,
            uid,
        ):
            """Append the user message and stream the assistant reply into it."""
            if not transcript_processor:
                bot_message = "Transcript processor not initialized."
                chat_history.append((message, bot_message))
                return "", chat_history

            chat_history.append((message, ""))
            for chunk in chat(
                message,
                chat_history[:-1],  # Exclude the current incomplete message
                transcript_processor,
                cid,
                rsid,
                origin,
                ct,
                uid,
            ):
                chat_history[-1] = (message, chunk)
                yield "", chat_history

        msg.submit(
            respond,
            [
                msg,
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            [msg, chatbot],
        )

        # Handle initial loading with streaming
        def on_app_load(request: gr.Request):
            """Parse query params and build the transcript processor.

            Returns [chatbot_value, transcript_processor, cid, rsid, origin,
            ct, turl, uid]; the trailing slots are None on any failure.
            """
            turls = None
            cid = request.query_params.get("cid", None)
            rsid = request.query_params.get("rsid", None)
            origin = request.query_params.get("origin", None)
            ct = request.query_params.get("ct", None)
            turl = request.query_params.get("turl", None)
            uid = request.query_params.get("uid", None)
            pnames = request.query_params.get("pnames", None)

            required_params = ["cid", "rsid", "origin", "ct", "turl", "uid"]
            missing_params = [
                param
                for param in required_params
                if request.query_params.get(param) is None
            ]
            if missing_params:
                error_message = (
                    f"Missing required parameters: {', '.join(missing_params)}"
                )
                chatbot_value = [(None, error_message)]
                return [chatbot_value, None, None, None, None, None, None, None]

            if ct == "rp":
                # FIX: pnames is mandatory for "rp" calls but is not listed in
                # required_params; without this guard a missing pnames raised
                # an uncaught AttributeError on pnames.split below.
                if pnames is None:
                    chatbot_value = [(None, "Missing required parameters: pnames")]
                    return [chatbot_value, None, None, None, None, None, None, None]
                # split turls based on ,
                turls = turl.split(",")
                pnames = [pname.replace("_", " ") for pname in pnames.split(",")]

            try:
                if turls:
                    transcript_data = []
                    for turl in turls:
                        print("Getting Transcript for URL")
                        transcript_data.append(get_transcript_for_url(turl))
                    print("Now creating Processor")
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data,
                        call_type=ct,
                        person_names=pnames,
                    )
                else:
                    transcript_data = get_transcript_for_url(turl)
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data, call_type=ct
                    )

                # Initialize with empty message
                chatbot_value = [(None, "")]
                # Return initial values with the transcript processor
                return [
                    chatbot_value,
                    transcript_processor,
                    cid,
                    rsid,
                    origin,
                    ct,
                    turl,
                    uid,
                ]
            except Exception as e:
                print(e)
                error_message = f"Error processing call_id {cid}: {str(e)}"
                chatbot_value = [(None, error_message)]
                return [chatbot_value, None, None, None, None, None, None, None]

        def display_processing_message(chatbot_value):
            """Display the processing message while maintaining state."""
            # Create new chatbot value with processing message
            new_chatbot_value = [
                (None, "Video is being processed. Please wait for the results...")
            ]
            # Return all states to maintain them
            return new_chatbot_value

        def stream_initial_analysis(
            chatbot_value, transcript_processor, cid, rsid, origin, ct, uid
        ):
            """Stream the initial analysis into the first chatbot message."""
            if not transcript_processor:
                return chatbot_value
            try:
                for chunk in get_initial_analysis(
                    transcript_processor, cid, rsid, origin, ct, uid
                ):
                    # Update the existing message instead of creating a new one
                    chatbot_value[0] = (None, chunk)
                    yield chatbot_value
            except Exception as e:
                chatbot_value[0] = (None, f"Error during analysis: {str(e)}")
                yield chatbot_value

        demo.load(
            on_app_load,
            inputs=None,
            outputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                turl_state,
                uid_state,
            ],
        ).then(
            display_processing_message,
            inputs=[chatbot],
            outputs=[chatbot],
        )
        # TODO(review): the initial-analysis streaming step is intentionally
        # disabled; re-enable by chaining it after display_processing_message:
        # .then(
        #     stream_initial_analysis,
        #     inputs=[
        #         chatbot,
        #         transcript_processor_state,
        #         call_id_state,
        #         colab_id_state,
        #         origin_state,
        #         ct_state,
        #         uid_state,
        #     ],
        #     outputs=[chatbot],
        # )
    return demo


def main():
    """Set up the OpenAI key, build the UI, and launch the app."""
    try:
        setup_openai_key()
        demo = create_chat_interface()
        demo.launch(share=True)
    except Exception as e:
        print(f"Error starting application: {str(e)}")
        raise


if __name__ == "__main__":
    main()