"""Gradio chat application for analyzing call transcripts with OpenAI."""
import json | |
from typing import Generator, List | |
import gradio as gr | |
from openai import OpenAI | |
from crop_utils import get_image_crop | |
from prompts import ( | |
get_chat_system_prompt, | |
get_live_event_system_prompt, | |
get_live_event_user_prompt, | |
get_street_interview_prompt, | |
get_street_interview_system_prompt, | |
) | |
from transcript import TranscriptProcessor | |
from utils import css, get_transcript_for_url, head | |
from utils import openai_tools as tools | |
from utils import setup_openai_key | |
client = OpenAI() | |
def get_initial_analysis( | |
transcript_processor: TranscriptProcessor, cid, rsid, origin, ct, uid | |
) -> Generator[str, None, None]: | |
"""Perform initial analysis of the transcript using OpenAI.""" | |
try: | |
transcript = transcript_processor.get_transcript() | |
speaker_mapping = transcript_processor.speaker_mapping | |
client = OpenAI() | |
if "localhost" in origin: | |
link_start = "http" | |
else: | |
link_start = "https" | |
if ct == "si": # street interview | |
user_prompt = get_street_interview_prompt(transcript, uid, rsid, link_start) | |
system_prompt = get_street_interview_system_prompt(cid, rsid, origin, ct) | |
completion = client.chat.completions.create( | |
model="gpt-4o", | |
messages=[ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
], | |
stream=True, | |
temperature=0.1, | |
) | |
else: | |
system_prompt = get_live_event_system_prompt( | |
cid, rsid, origin, ct, speaker_mapping, transcript | |
) | |
user_prompt = get_live_event_user_prompt(uid, link_start) | |
completion = client.chat.completions.create( | |
model="gpt-4o", | |
messages=[ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
], | |
stream=True, | |
temperature=0.1, | |
) | |
collected_messages = [] | |
# Iterate through the stream | |
for chunk in completion: | |
if chunk.choices[0].delta.content is not None: | |
chunk_message = chunk.choices[0].delta.content | |
collected_messages.append(chunk_message) | |
# Yield the accumulated message so far | |
yield "".join(collected_messages) | |
except Exception as e: | |
print(f"Error in initial analysis: {str(e)}") | |
yield "An error occurred during initial analysis. Please check your API key and file path." | |
def chat( | |
message: str, | |
chat_history: List, | |
transcript_processor: TranscriptProcessor, | |
cid, | |
rsid, | |
origin, | |
ct, | |
uid, | |
): | |
try: | |
client = OpenAI() | |
if "localhost" in origin: | |
link_start = "http" | |
else: | |
link_start = "https" | |
speaker_mapping = transcript_processor.speaker_mapping | |
system_prompt = get_chat_system_prompt( | |
cid=cid, | |
rsid=rsid, | |
origin=origin, | |
ct=ct, | |
speaker_mapping=speaker_mapping, | |
transcript=transcript_processor.get_transcript(), | |
link_start=link_start, | |
) | |
messages = [{"role": "system", "content": system_prompt}] | |
for user_msg, assistant_msg in chat_history: | |
if user_msg is not None: | |
messages.append({"role": "user", "content": user_msg}) | |
if assistant_msg is not None: | |
messages.append({"role": "assistant", "content": assistant_msg}) | |
# Add the current message | |
messages.append({"role": "user", "content": message}) | |
completion = client.chat.completions.create( | |
model="gpt-4o", | |
messages=messages, | |
tools=tools, | |
stream=True, | |
temperature=0.3, | |
) | |
collected_messages = [] | |
tool_calls_detected = False | |
for chunk in completion: | |
if chunk.choices[0].delta.tool_calls: | |
tool_calls_detected = True | |
# Handle tool calls without streaming | |
response = client.chat.completions.create( | |
model="gpt-4o", | |
messages=messages, | |
tools=tools, | |
) | |
if response.choices[0].message.tool_calls: | |
tool_call = response.choices[0].message.tool_calls[0] | |
if tool_call.function.name == "get_image": | |
# Return the image directly in the chat | |
image_data = get_image_crop(cid, rsid, uid) | |
messages.append(response.choices[0].message) | |
function_call_result_message = { | |
"role": "tool", | |
"content": "Here are the Image Crops", | |
"name": tool_call.function.name, | |
"tool_call_id": tool_call.id, | |
} | |
messages.append(function_call_result_message) | |
yield image_data | |
return | |
if tool_call.function.name == "correct_speaker_name_with_url": | |
args = eval(tool_call.function.arguments) | |
url = args.get("url", None) | |
if url: | |
transcript_processor.correct_speaker_mapping_with_agenda( | |
url | |
) | |
corrected_speaker_mapping = ( | |
transcript_processor.speaker_mapping | |
) | |
messages.append(response.choices[0].message) | |
function_call_result_message = { | |
"role": "tool", | |
"content": json.dumps( | |
{ | |
"speaker_mapping": f"Corrected Speaker Mapping... {corrected_speaker_mapping}" | |
} | |
), | |
"name": tool_call.function.name, | |
"tool_call_id": tool_call.id, | |
} | |
messages.append(function_call_result_message) | |
# Get final response after tool call | |
final_response = client.chat.completions.create( | |
model="gpt-4o", | |
messages=messages, | |
stream=True, | |
) | |
collected_chunk = "" | |
for final_chunk in final_response: | |
if final_chunk.choices[0].delta.content: | |
collected_chunk += final_chunk.choices[ | |
0 | |
].delta.content | |
yield collected_chunk | |
return | |
else: | |
function_call_result_message = { | |
"role": "tool", | |
"content": "No URL Provided", | |
"name": tool_call.function.name, | |
"tool_call_id": tool_call.id, | |
} | |
elif tool_call.function.name == "correct_call_type": | |
args = eval(tool_call.function.arguments) | |
call_type = args.get("call_type", None) | |
if call_type: | |
# Stream the analysis for corrected call type | |
for content in get_initial_analysis( | |
transcript_processor, | |
call_type, | |
rsid, | |
origin, | |
call_type, | |
uid, | |
): | |
yield content | |
return | |
break # Exit streaming loop if tool calls detected | |
if not tool_calls_detected and chunk.choices[0].delta.content is not None: | |
chunk_message = chunk.choices[0].delta.content | |
collected_messages.append(chunk_message) | |
yield "".join(collected_messages) | |
except Exception as e: | |
print(f"Unexpected error in chat: {str(e)}") | |
import traceback | |
print(f"Traceback: {traceback.format_exc()}") | |
yield "Sorry, there was an error processing your request." | |
def create_chat_interface(): | |
"""Create and configure the chat interface.""" | |
with gr.Blocks( | |
fill_height=True, | |
fill_width=True, | |
css=css, | |
head=head, | |
theme=gr.themes.Default( | |
font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"] | |
), | |
) as demo: | |
chatbot = gr.Chatbot( | |
elem_id="chatbot_box", | |
layout="bubble", | |
show_label=False, | |
show_share_button=False, | |
show_copy_all_button=False, | |
show_copy_button=False, | |
render=True, | |
) | |
msg = gr.Textbox(elem_id="chatbot_textbox", show_label=False) | |
transcript_processor_state = gr.State() # maintain state of imp things | |
call_id_state = gr.State() | |
colab_id_state = gr.State() | |
origin_state = gr.State() | |
ct_state = gr.State() | |
turl_state = gr.State() | |
uid_state = gr.State() | |
iframe_html = "<iframe id='link-frame'></iframe>" | |
gr.HTML(value=iframe_html) # Add iframe to the UI | |
def respond( | |
message: str, | |
chat_history: List, | |
transcript_processor, | |
cid, | |
rsid, | |
origin, | |
ct, | |
uid, | |
): | |
if not transcript_processor: | |
bot_message = "Transcript processor not initialized." | |
chat_history.append((message, bot_message)) | |
return "", chat_history | |
chat_history.append((message, "")) | |
for chunk in chat( | |
message, | |
chat_history[:-1], # Exclude the current incomplete message | |
transcript_processor, | |
cid, | |
rsid, | |
origin, | |
ct, | |
uid, | |
): | |
chat_history[-1] = (message, chunk) | |
yield "", chat_history | |
msg.submit( | |
respond, | |
[ | |
msg, | |
chatbot, | |
transcript_processor_state, | |
call_id_state, | |
colab_id_state, | |
origin_state, | |
ct_state, | |
uid_state, | |
], | |
[msg, chatbot], | |
) | |
# Handle initial loading with streaming | |
def on_app_load(request: gr.Request): | |
turls = None | |
cid = request.query_params.get("cid", None) | |
rsid = request.query_params.get("rsid", None) | |
origin = request.query_params.get("origin", None) | |
ct = request.query_params.get("ct", None) | |
turl = request.query_params.get("turl", None) | |
uid = request.query_params.get("uid", None) | |
pnames = request.query_params.get("pnames", None) | |
required_params = ["cid", "rsid", "origin", "ct", "turl", "uid"] | |
missing_params = [ | |
param | |
for param in required_params | |
if request.query_params.get(param) is None | |
] | |
if missing_params: | |
error_message = ( | |
f"Missing required parameters: {', '.join(missing_params)}" | |
) | |
chatbot_value = [(None, error_message)] | |
return [chatbot_value, None, None, None, None, None, None, None] | |
if ct == "rp": | |
# split turls based on , | |
turls = turl.split(",") | |
pnames = [pname.replace("_", " ") for pname in pnames.split(",")] | |
try: | |
if turls: | |
transcript_data = [] | |
for turl in turls: | |
print("Getting Transcript for URL") | |
transcript_data.append(get_transcript_for_url(turl)) | |
print("Now creating Processor") | |
transcript_processor = TranscriptProcessor( | |
transcript_data=transcript_data, | |
call_type=ct, | |
person_names=pnames, | |
) | |
else: | |
transcript_data = get_transcript_for_url(turl) | |
transcript_processor = TranscriptProcessor( | |
transcript_data=transcript_data, call_type=ct | |
) | |
# Initialize with empty message | |
chatbot_value = [(None, "")] | |
# Return initial values with the transcript processor | |
return [ | |
chatbot_value, | |
transcript_processor, | |
cid, | |
rsid, | |
origin, | |
ct, | |
turl, | |
uid, | |
] | |
except Exception as e: | |
print(e) | |
error_message = f"Error processing call_id {cid}: {str(e)}" | |
chatbot_value = [(None, error_message)] | |
return [chatbot_value, None, None, None, None, None, None, None] | |
def display_processing_message(chatbot_value): | |
"""Display the processing message while maintaining state.""" | |
# Create new chatbot value with processing message | |
new_chatbot_value = [ | |
(None, "Video is being processed. Please wait for the results...") | |
] | |
# Return all states to maintain them | |
return new_chatbot_value | |
def stream_initial_analysis( | |
chatbot_value, transcript_processor, cid, rsid, origin, ct, uid | |
): | |
if not transcript_processor: | |
return chatbot_value | |
try: | |
for chunk in get_initial_analysis( | |
transcript_processor, cid, rsid, origin, ct, uid | |
): | |
# Update the existing message instead of creating a new one | |
chatbot_value[0] = (None, chunk) | |
yield chatbot_value | |
except Exception as e: | |
chatbot_value[0] = (None, f"Error during analysis: {str(e)}") | |
yield chatbot_value | |
demo.load( | |
on_app_load, | |
inputs=None, | |
outputs=[ | |
chatbot, | |
transcript_processor_state, | |
call_id_state, | |
colab_id_state, | |
origin_state, | |
ct_state, | |
turl_state, | |
uid_state, | |
], | |
).then( | |
display_processing_message, | |
inputs=[chatbot], | |
outputs=[chatbot], | |
) | |
# .then( | |
# stream_initial_analysis, | |
# inputs=[ | |
# chatbot, | |
# transcript_processor_state, | |
# call_id_state, | |
# colab_id_state, | |
# origin_state, | |
# ct_state, | |
# uid_state, | |
# ], | |
# outputs=[chatbot], | |
# ) | |
return demo | |
def main(): | |
"""Main function to run the application.""" | |
try: | |
setup_openai_key() | |
demo = create_chat_interface() | |
demo.launch(share=True) | |
except Exception as e: | |
print(f"Error starting application: {str(e)}") | |
raise | |
if __name__ == "__main__": | |
main() | |