AhmadMustafa's picture
update: show crops
9f1e459
raw
history blame
16.2 kB
import json
from typing import Generator, List
import gradio as gr
from openai import OpenAI
from crop_utils import get_image_crop
from prompts import (
get_chat_system_prompt,
get_live_event_system_prompt,
get_live_event_user_prompt,
get_street_interview_prompt,
get_street_interview_system_prompt,
)
from transcript import TranscriptProcessor
from utils import css, get_transcript_for_url, head
from utils import openai_tools as tools
from utils import setup_openai_key
# Module-level OpenAI client, constructed once when this module is imported.
# get_initial_analysis() and chat() each build their own local ``client``, so
# the functions in this file do not read this one.
# NOTE(review): this runs at import time, i.e. before main() calls
# setup_openai_key() -- confirm the API key is already available by then.
client = OpenAI()
def get_initial_analysis(
    transcript_processor: "TranscriptProcessor", cid, rsid, origin, ct, uid
) -> Generator[str, None, None]:
    """Stream an initial analysis of the transcript from the OpenAI chat API.

    Args:
        transcript_processor: Object providing ``get_transcript()`` and
            ``speaker_mapping``.
        cid: Forwarded to the system-prompt builders.
        rsid: Forwarded to the prompt builders.
        origin: Forwarded to the system-prompt builders. When it contains
            ``"localhost"`` the link scheme handed to the user-prompt
            builders is ``"http"``, otherwise ``"https"``.
        ct: Call type. ``"si"`` selects the street-interview prompts; any
            other value selects the live-event prompts.
        uid: Forwarded to the user-prompt builders.

    Yields:
        The text accumulated so far, once per streamed content delta. If
        anything raises, the error is printed and a single fixed error
        message is yielded instead of propagating the exception.
    """
    try:
        transcript = transcript_processor.get_transcript()
        speaker_mapping = transcript_processor.speaker_mapping
        client = OpenAI()
        link_start = "http" if "localhost" in origin else "https"

        # Only the prompts differ between call types; the request is shared.
        if ct == "si":  # street interview
            user_prompt = get_street_interview_prompt(
                transcript, uid, rsid, link_start
            )
            system_prompt = get_street_interview_system_prompt(
                cid, rsid, origin, ct
            )
        else:
            system_prompt = get_live_event_system_prompt(
                cid, rsid, origin, ct, speaker_mapping, transcript
            )
            user_prompt = get_live_event_user_prompt(uid, link_start)

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            stream=True,
            temperature=0.1,
        )

        collected_messages = []
        for chunk in completion:
            # A stream chunk may carry an empty ``choices`` list; indexing it
            # blindly would abort the whole stream with an IndexError.
            if not chunk.choices:
                continue
            chunk_message = chunk.choices[0].delta.content
            if chunk_message is not None:
                collected_messages.append(chunk_message)
                # Yield the accumulated message so far
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Error in initial analysis: {str(e)}")
        yield "An error occurred during initial analysis. Please check your API key and file path."
def _parse_tool_arguments(raw_arguments) -> dict:
    """Decode a tool call's JSON argument string; return {} if it is unusable."""
    if not raw_arguments:
        return {}
    try:
        parsed = json.loads(raw_arguments)
    except (TypeError, ValueError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _accumulate_stream(stream):
    """Yield the growing text of a streamed completion, once per content delta."""
    parts = []
    for chunk in stream:
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece:
            parts.append(piece)
            yield "".join(parts)


def _handle_tool_call(client, messages, transcript_processor, cid, rsid, origin, uid):
    """Re-issue the request without streaming and act on the first tool call."""
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        tools=tools,
    )
    assistant_message = response.choices[0].message
    if not assistant_message.tool_calls:
        # The follow-up answered in plain text; show it rather than nothing.
        if assistant_message.content:
            yield assistant_message.content
        return

    tool_call = assistant_message.tool_calls[0]
    tool_name = tool_call.function.name

    if tool_name == "get_image":
        # Return the image directly in the chat
        yield get_image_crop(cid, rsid, uid)
        return

    if tool_name == "correct_call_type":
        call_type = _parse_tool_arguments(tool_call.function.arguments).get(
            "call_type"
        )
        if call_type:
            # Re-run the analysis with the corrected call type. The call id
            # stays ``cid``; only the ``ct`` argument is replaced.
            yield from get_initial_analysis(
                transcript_processor, cid, rsid, origin, call_type, uid
            )
        return

    if tool_name != "correct_speaker_name_with_url":
        return

    url = _parse_tool_arguments(tool_call.function.arguments).get("url")
    if url:
        transcript_processor.correct_speaker_mapping_with_agenda(url)
        corrected_speaker_mapping = transcript_processor.speaker_mapping
        tool_result = json.dumps(
            {
                "speaker_mapping": f"Corrected Speaker Mapping... {corrected_speaker_mapping}"
            }
        )
    else:
        tool_result = "No URL Provided"

    # Hand the tool result back to the model and stream its final answer.
    messages.append(assistant_message)
    messages.append(
        {
            "role": "tool",
            "content": tool_result,
            "name": tool_name,
            "tool_call_id": tool_call.id,
        }
    )
    final_response = client.chat.completions.create(
        model="gpt-4o",
        messages=messages,
        stream=True,
    )
    yield from _accumulate_stream(final_response)


def chat(
    message: str,
    chat_history: List,
    transcript_processor: "TranscriptProcessor",
    cid,
    rsid,
    origin,
    ct,
    uid,
):
    """Stream the assistant's reply to ``message``, executing tool calls.

    The system prompt is built from the transcript and speaker mapping of
    ``transcript_processor``; ``chat_history`` (pairs of user / assistant
    messages, either of which may be ``None``) is replayed before the new
    message. The request is streamed; as soon as a streamed delta carries
    tool calls, the request is repeated without streaming and the first tool
    call of that response is handled:

    * ``get_image``: yields the result of ``get_image_crop(cid, rsid, uid)``.
    * ``correct_speaker_name_with_url``: corrects the speaker mapping from
      the given URL (or reports "No URL Provided" to the model) and streams
      the model's final answer.
    * ``correct_call_type``: streams a fresh ``get_initial_analysis`` with
      the corrected call type.

    Tool arguments are decoded with ``json.loads``; they are model-generated
    text and must never be passed to ``eval``.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_msg, assistant_msg)`` pairs.
        transcript_processor: Provides ``get_transcript()``,
            ``speaker_mapping`` and ``correct_speaker_mapping_with_agenda``.
        cid, rsid, origin, ct, uid: Forwarded to the prompt builder and tool
            handlers. ``origin`` containing ``"localhost"`` selects the
            ``"http"`` link scheme, otherwise ``"https"``.

    Yields:
        The accumulated reply text after each streamed delta, or the value
        returned by ``get_image_crop``. If anything raises, the error and
        traceback are printed and a fixed apology message is yielded.
    """
    try:
        client = OpenAI()
        link_start = "http" if "localhost" in origin else "https"

        system_prompt = get_chat_system_prompt(
            cid=cid,
            rsid=rsid,
            origin=origin,
            ct=ct,
            speaker_mapping=transcript_processor.speaker_mapping,
            transcript=transcript_processor.get_transcript(),
            link_start=link_start,
        )

        messages = [{"role": "system", "content": system_prompt}]
        for user_msg, assistant_msg in chat_history:
            if user_msg is not None:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg is not None:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Add the current message
        messages.append({"role": "user", "content": message})

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=tools,
            stream=True,
            temperature=0.3,
        )

        collected_messages = []
        for chunk in completion:
            # Skip chunks without choices instead of failing on ``[0]``.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta.tool_calls:
                # Tool calls are handled without streaming; stop consuming
                # the original stream once one is detected.
                yield from _handle_tool_call(
                    client, messages, transcript_processor, cid, rsid, origin, uid
                )
                return
            if delta.content is not None:
                collected_messages.append(delta.content)
                yield "".join(collected_messages)
    except Exception as e:
        print(f"Unexpected error in chat: {str(e)}")
        import traceback

        print(f"Traceback: {traceback.format_exc()}")
        yield "Sorry, there was an error processing your request."
def create_chat_interface():
    """Create and configure the chat interface.

    Builds a ``gr.Blocks`` app containing a chatbot, a textbox, per-session
    state holders and an iframe, and wires two event chains:

    * ``msg.submit`` -> ``respond``: streams ``chat()`` output into the
      chatbot and clears the textbox.
    * ``demo.load`` -> ``on_app_load`` -> ``display_processing_message``:
      reads the query parameters, builds the ``TranscriptProcessor`` and
      fills the state holders.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(
        fill_height=True,
        fill_width=True,
        css=css,
        head=head,
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
        ),
    ) as demo:
        chatbot = gr.Chatbot(
            elem_id="chatbot_box",
            layout="bubble",
            show_label=False,
            show_share_button=False,
            show_copy_all_button=False,
            show_copy_button=False,
            render=True,
        )
        msg = gr.Textbox(elem_id="chatbot_textbox", show_label=False)

        # State holders filled by on_app_load and read back by respond.
        transcript_processor_state = gr.State()
        call_id_state = gr.State()
        colab_id_state = gr.State()
        origin_state = gr.State()
        ct_state = gr.State()
        turl_state = gr.State()
        uid_state = gr.State()

        iframe_html = "<iframe id='link-frame'></iframe>"
        gr.HTML(value=iframe_html)  # Add iframe to the UI

        def respond(
            message: str,
            chat_history: List,
            transcript_processor,
            cid,
            rsid,
            origin,
            ct,
            uid,
        ):
            """Stream the reply to ``message`` into the chat history."""
            if not transcript_processor:
                bot_message = "Transcript processor not initialized."
                chat_history.append((message, bot_message))
                # This function is a generator: the update must be yielded,
                # a plain ``return value`` would never reach the UI.
                yield "", chat_history
                return

            chat_history.append((message, ""))
            for chunk in chat(
                message,
                chat_history[:-1],  # Exclude the current incomplete message
                transcript_processor,
                cid,
                rsid,
                origin,
                ct,
                uid,
            ):
                chat_history[-1] = (message, chunk)
                yield "", chat_history

        msg.submit(
            respond,
            [
                msg,
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                uid_state,
            ],
            [msg, chatbot],
        )

        # Handle initial loading with streaming
        def on_app_load(request: gr.Request):
            """Read the query parameters and build the transcript processor."""

            def load_failed(error_message):
                # Chatbot shows the error; all seven state outputs are reset.
                return [[(None, error_message)]] + [None] * 7

            params = request.query_params
            cid = params.get("cid", None)
            rsid = params.get("rsid", None)
            origin = params.get("origin", None)
            ct = params.get("ct", None)
            turl = params.get("turl", None)
            uid = params.get("uid", None)
            pnames = params.get("pnames", None)

            required_params = ["cid", "rsid", "origin", "ct", "turl", "uid"]
            missing_params = [
                param for param in required_params if params.get(param) is None
            ]
            if missing_params:
                return load_failed(
                    f"Missing required parameters: {', '.join(missing_params)}"
                )

            turls = None
            if ct == "rp":
                # "rp" calls carry comma-separated transcript URLs and person
                # names; without pnames the split below would raise.
                if not pnames:
                    return load_failed("Missing required parameters: pnames")
                turls = turl.split(",")
                pnames = [pname.replace("_", " ") for pname in pnames.split(",")]

            try:
                if turls:
                    transcript_data = []
                    # Distinct loop name so ``turl`` (stored in state below)
                    # keeps the original query-parameter value.
                    for single_url in turls:
                        print("Getting Transcript for URL")
                        transcript_data.append(get_transcript_for_url(single_url))
                    print("Now creating Processor")
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data,
                        call_type=ct,
                        person_names=pnames,
                    )
                else:
                    transcript_data = get_transcript_for_url(turl)
                    transcript_processor = TranscriptProcessor(
                        transcript_data=transcript_data, call_type=ct
                    )

                # Initialize with empty message
                chatbot_value = [(None, "")]
                # Return initial values with the transcript processor
                return [
                    chatbot_value,
                    transcript_processor,
                    cid,
                    rsid,
                    origin,
                    ct,
                    turl,
                    uid,
                ]
            except Exception as e:
                print(e)
                return load_failed(f"Error processing call_id {cid}: {str(e)}")

        def display_processing_message(chatbot_value):
            """Show the processing notice unless loading already reported an error."""
            try:
                existing_reply = chatbot_value[-1][1]
            except (IndexError, KeyError, TypeError):
                existing_reply = None
            if existing_reply:
                # on_app_load put an error message here; do not overwrite it.
                return chatbot_value
            return [
                (None, "Video is being processed. Please wait for the results...")
            ]

        def stream_initial_analysis(
            chatbot_value, transcript_processor, cid, rsid, origin, ct, uid
        ):
            """Stream get_initial_analysis() output into the first chat message.

            Currently not wired to any event (see the commented-out ``.then``
            below).
            """
            if not transcript_processor:
                # Generator: yield the unchanged value so it reaches the UI.
                yield chatbot_value
                return
            try:
                for chunk in get_initial_analysis(
                    transcript_processor, cid, rsid, origin, ct, uid
                ):
                    # Update the existing message instead of creating a new one
                    chatbot_value[0] = (None, chunk)
                    yield chatbot_value
            except Exception as e:
                chatbot_value[0] = (None, f"Error during analysis: {str(e)}")
                yield chatbot_value

        demo.load(
            on_app_load,
            inputs=None,
            outputs=[
                chatbot,
                transcript_processor_state,
                call_id_state,
                colab_id_state,
                origin_state,
                ct_state,
                turl_state,
                uid_state,
            ],
        ).then(
            display_processing_message,
            inputs=[chatbot],
            outputs=[chatbot],
        )
        # .then(
        #     stream_initial_analysis,
        #     inputs=[
        #         chatbot,
        #         transcript_processor_state,
        #         call_id_state,
        #         colab_id_state,
        #         origin_state,
        #         ct_state,
        #         uid_state,
        #     ],
        #     outputs=[chatbot],
        # )
    return demo
def main():
    """Run the application.

    Calls ``setup_openai_key()``, builds the chat interface and launches it
    with ``share=True``. Any exception raised along the way is printed and
    then re-raised to the caller.
    """
    try:
        setup_openai_key()
        app = create_chat_interface()
        app.launch(share=True)
    except Exception as startup_error:
        print(f"Error starting application: {str(startup_error)}")
        raise
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()