Spaces:
Sleeping
Sleeping
Commit
·
37b73c5
1
Parent(s):
768a878
update: add streamimng
Browse files
app.py
CHANGED
@@ -411,11 +411,21 @@ Rank topics based on their potential virality and engagement for social media cl
|
|
411 |
},
|
412 |
{"role": "user", "content": prompt},
|
413 |
],
|
|
|
414 |
)
|
415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
except Exception as e:
|
417 |
print(f"Error in initial analysis: {str(e)}")
|
418 |
-
|
419 |
|
420 |
|
421 |
def chat(
|
@@ -498,63 +508,77 @@ In the URL, make sure that after RSID there is ? and then rest of the fields are
|
|
498 |
messages.append({"role": "user", "content": message})
|
499 |
|
500 |
completion = client.chat.completions.create(
|
501 |
-
model="gpt-4o-mini",
|
502 |
-
messages=messages,
|
503 |
-
tools=tools,
|
504 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
"
|
|
|
|
|
|
|
|
|
521 |
}
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
|
552 |
except Exception as e:
|
553 |
print(f"Unexpected error in chat: {str(e)}")
|
554 |
import traceback
|
555 |
|
556 |
print(f"Traceback: {traceback.format_exc()}")
|
557 |
-
|
558 |
|
559 |
|
560 |
def create_chat_interface():
|
@@ -688,18 +712,62 @@ def create_chat_interface():
|
|
688 |
iframe_html = "<iframe id='link-frame'></iframe>"
|
689 |
gr.HTML(value=iframe_html) # Add iframe to the UI
|
690 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
def on_app_load(request: gr.Request):
|
692 |
cid = request.query_params.get("cid", None)
|
693 |
rsid = request.query_params.get("rsid", None)
|
694 |
origin = request.query_params.get("origin", None)
|
695 |
ct = request.query_params.get("ct", None)
|
696 |
turl = request.query_params.get("turl", None)
|
|
|
697 |
required_params = ["cid", "rsid", "origin", "ct", "turl"]
|
698 |
missing_params = [
|
699 |
param
|
700 |
for param in required_params
|
701 |
if request.query_params.get(param) is None
|
702 |
]
|
|
|
703 |
if missing_params:
|
704 |
error_message = (
|
705 |
f"Missing required parameters: {', '.join(missing_params)}"
|
@@ -715,33 +783,16 @@ def create_chat_interface():
|
|
715 |
None,
|
716 |
]
|
717 |
|
718 |
-
# if any param is missing, return error
|
719 |
-
if not cid or not rsid or not origin or not ct or not turl:
|
720 |
-
error_message = "Error processing"
|
721 |
-
chatbot_value = [(None, error_message)]
|
722 |
-
return [
|
723 |
-
chatbot_value,
|
724 |
-
None,
|
725 |
-
None,
|
726 |
-
None,
|
727 |
-
None,
|
728 |
-
None,
|
729 |
-
None,
|
730 |
-
]
|
731 |
-
|
732 |
try:
|
733 |
transcript_data = get_transcript_for_url(turl)
|
734 |
transcript_processor = TranscriptProcessor(
|
735 |
transcript_data=transcript_data
|
736 |
)
|
737 |
-
initial_analysis = get_initial_analysis(
|
738 |
-
transcript_processor, cid, rsid, origin, ct
|
739 |
-
)
|
740 |
|
741 |
-
|
742 |
-
|
743 |
-
] # initialized with initial analysis and assistant is None
|
744 |
|
|
|
745 |
return [
|
746 |
chatbot_value,
|
747 |
transcript_processor,
|
@@ -764,6 +815,19 @@ def create_chat_interface():
|
|
764 |
None,
|
765 |
]
|
766 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
767 |
demo.load(
|
768 |
on_app_load,
|
769 |
inputs=None,
|
@@ -776,36 +840,9 @@ def create_chat_interface():
|
|
776 |
ct_state,
|
777 |
turl_state,
|
778 |
],
|
779 |
-
)
|
780 |
-
|
781 |
-
|
782 |
-
message: str,
|
783 |
-
chat_history: List,
|
784 |
-
transcript_processor,
|
785 |
-
cid,
|
786 |
-
rsid,
|
787 |
-
origin,
|
788 |
-
ct,
|
789 |
-
):
|
790 |
-
if not transcript_processor:
|
791 |
-
bot_message = "Transcript processor not initialized."
|
792 |
-
else:
|
793 |
-
bot_message = chat(
|
794 |
-
message,
|
795 |
-
chat_history,
|
796 |
-
transcript_processor,
|
797 |
-
cid,
|
798 |
-
rsid,
|
799 |
-
origin,
|
800 |
-
ct,
|
801 |
-
)
|
802 |
-
chat_history.append((message, bot_message))
|
803 |
-
return "", chat_history
|
804 |
-
|
805 |
-
msg.submit(
|
806 |
-
respond,
|
807 |
-
[
|
808 |
-
msg,
|
809 |
chatbot,
|
810 |
transcript_processor_state,
|
811 |
call_id_state,
|
@@ -813,7 +850,7 @@ def create_chat_interface():
|
|
813 |
origin_state,
|
814 |
ct_state,
|
815 |
],
|
816 |
-
[
|
817 |
)
|
818 |
|
819 |
return demo
|
|
|
411 |
},
|
412 |
{"role": "user", "content": prompt},
|
413 |
],
|
414 |
+
stream=True,
|
415 |
)
|
416 |
+
|
417 |
+
collected_messages = []
|
418 |
+
# Iterate through the stream
|
419 |
+
for chunk in completion:
|
420 |
+
if chunk.choices[0].delta.content is not None:
|
421 |
+
chunk_message = chunk.choices[0].delta.content
|
422 |
+
collected_messages.append(chunk_message)
|
423 |
+
# Yield the accumulated message so far
|
424 |
+
yield "".join(collected_messages)
|
425 |
+
|
426 |
except Exception as e:
|
427 |
print(f"Error in initial analysis: {str(e)}")
|
428 |
+
yield "An error occurred during initial analysis. Please check your API key and file path."
|
429 |
|
430 |
|
431 |
def chat(
|
|
|
508 |
messages.append({"role": "user", "content": message})
|
509 |
|
510 |
completion = client.chat.completions.create(
|
511 |
+
model="gpt-4o-mini", messages=messages, tools=tools, stream=True
|
|
|
|
|
512 |
)
|
513 |
+
collected_messages = []
|
514 |
+
tool_calls_detected = False
|
515 |
+
|
516 |
+
for chunk in completion:
|
517 |
+
if chunk.choices[0].delta.tool_calls:
|
518 |
+
tool_calls_detected = True
|
519 |
+
# Handle tool calls without streaming
|
520 |
+
response = client.chat.completions.create(
|
521 |
+
model="gpt-4o-mini",
|
522 |
+
messages=messages,
|
523 |
+
tools=tools,
|
524 |
+
)
|
525 |
|
526 |
+
if response.choices[0].message.tool_calls:
|
527 |
+
tool_call = response.choices[0].message.tool_calls[0]
|
528 |
+
if tool_call.function.name == "correct_speaker_name_with_url":
|
529 |
+
args = eval(tool_call.function.arguments)
|
530 |
+
url = args.get("url", None)
|
531 |
+
if url:
|
532 |
+
transcript_processor.correct_speaker_mapping_with_agenda(
|
533 |
+
url
|
534 |
+
)
|
535 |
+
corrected_speaker_mapping = (
|
536 |
+
transcript_processor.speaker_mapping
|
537 |
+
)
|
538 |
+
function_call_result_message = {
|
539 |
+
"role": "tool",
|
540 |
+
"content": json.dumps(
|
541 |
+
{"speaker_mapping": f"Corrected Speaker Mapping..."}
|
542 |
+
),
|
543 |
+
"name": tool_call.function.name,
|
544 |
+
"tool_call_id": tool_call.id,
|
545 |
}
|
546 |
+
messages.append(function_call_result_message)
|
547 |
+
|
548 |
+
# Get final response after tool call
|
549 |
+
final_response = client.chat.completions.create(
|
550 |
+
model="gpt-4o-mini", messages=messages, stream=True
|
551 |
+
)
|
552 |
+
|
553 |
+
# Stream the final response
|
554 |
+
for final_chunk in final_response:
|
555 |
+
if final_chunk.choices[0].delta.content:
|
556 |
+
yield final_chunk.choices[0].delta.content
|
557 |
+
return
|
558 |
+
|
559 |
+
elif tool_call.function.name == "correct_call_type":
|
560 |
+
args = eval(tool_call.function.arguments)
|
561 |
+
call_type = args.get("call_type", None)
|
562 |
+
if call_type:
|
563 |
+
# Stream the analysis for corrected call type
|
564 |
+
for content in get_initial_analysis(
|
565 |
+
transcript_processor, call_type, rsid, origin, call_type
|
566 |
+
):
|
567 |
+
yield content
|
568 |
+
return
|
569 |
+
break # Exit streaming loop if tool calls detected
|
570 |
+
|
571 |
+
if not tool_calls_detected and chunk.choices[0].delta.content is not None:
|
572 |
+
chunk_message = chunk.choices[0].delta.content
|
573 |
+
collected_messages.append(chunk_message)
|
574 |
+
yield "".join(collected_messages)
|
575 |
|
576 |
except Exception as e:
|
577 |
print(f"Unexpected error in chat: {str(e)}")
|
578 |
import traceback
|
579 |
|
580 |
print(f"Traceback: {traceback.format_exc()}")
|
581 |
+
yield "Sorry, there was an error processing your request."
|
582 |
|
583 |
|
584 |
def create_chat_interface():
|
|
|
712 |
iframe_html = "<iframe id='link-frame'></iframe>"
|
713 |
gr.HTML(value=iframe_html) # Add iframe to the UI
|
714 |
|
715 |
+
def respond(
|
716 |
+
message: str,
|
717 |
+
chat_history: List,
|
718 |
+
transcript_processor,
|
719 |
+
cid,
|
720 |
+
rsid,
|
721 |
+
origin,
|
722 |
+
ct,
|
723 |
+
):
|
724 |
+
if not transcript_processor:
|
725 |
+
bot_message = "Transcript processor not initialized."
|
726 |
+
chat_history.append((message, bot_message))
|
727 |
+
return "", chat_history
|
728 |
+
|
729 |
+
chat_history.append((message, ""))
|
730 |
+
for chunk in chat(
|
731 |
+
message,
|
732 |
+
chat_history[:-1], # Exclude the current incomplete message
|
733 |
+
transcript_processor,
|
734 |
+
cid,
|
735 |
+
rsid,
|
736 |
+
origin,
|
737 |
+
ct,
|
738 |
+
):
|
739 |
+
chat_history[-1] = (message, chunk)
|
740 |
+
yield "", chat_history
|
741 |
+
|
742 |
+
msg.submit(
|
743 |
+
respond,
|
744 |
+
[
|
745 |
+
msg,
|
746 |
+
chatbot,
|
747 |
+
transcript_processor_state,
|
748 |
+
call_id_state,
|
749 |
+
colab_id_state,
|
750 |
+
origin_state,
|
751 |
+
ct_state,
|
752 |
+
],
|
753 |
+
[msg, chatbot],
|
754 |
+
)
|
755 |
+
|
756 |
+
# Handle initial loading with streaming
|
757 |
def on_app_load(request: gr.Request):
|
758 |
cid = request.query_params.get("cid", None)
|
759 |
rsid = request.query_params.get("rsid", None)
|
760 |
origin = request.query_params.get("origin", None)
|
761 |
ct = request.query_params.get("ct", None)
|
762 |
turl = request.query_params.get("turl", None)
|
763 |
+
|
764 |
required_params = ["cid", "rsid", "origin", "ct", "turl"]
|
765 |
missing_params = [
|
766 |
param
|
767 |
for param in required_params
|
768 |
if request.query_params.get(param) is None
|
769 |
]
|
770 |
+
|
771 |
if missing_params:
|
772 |
error_message = (
|
773 |
f"Missing required parameters: {', '.join(missing_params)}"
|
|
|
783 |
None,
|
784 |
]
|
785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
786 |
try:
|
787 |
transcript_data = get_transcript_for_url(turl)
|
788 |
transcript_processor = TranscriptProcessor(
|
789 |
transcript_data=transcript_data
|
790 |
)
|
|
|
|
|
|
|
791 |
|
792 |
+
# Initialize with empty message
|
793 |
+
chatbot_value = [(None, "")]
|
|
|
794 |
|
795 |
+
# Return initial values with the transcript processor
|
796 |
return [
|
797 |
chatbot_value,
|
798 |
transcript_processor,
|
|
|
815 |
None,
|
816 |
]
|
817 |
|
818 |
+
def stream_initial_analysis(
|
819 |
+
chatbot_value, transcript_processor, cid, rsid, origin, ct
|
820 |
+
):
|
821 |
+
if transcript_processor:
|
822 |
+
for chunk in get_initial_analysis(
|
823 |
+
transcript_processor, cid, rsid, origin, ct
|
824 |
+
):
|
825 |
+
chatbot_value[0] = (None, chunk)
|
826 |
+
yield chatbot_value
|
827 |
+
else:
|
828 |
+
yield chatbot_value
|
829 |
+
|
830 |
+
# Modified load event to handle streaming
|
831 |
demo.load(
|
832 |
on_app_load,
|
833 |
inputs=None,
|
|
|
840 |
ct_state,
|
841 |
turl_state,
|
842 |
],
|
843 |
+
).then(
|
844 |
+
stream_initial_analysis,
|
845 |
+
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
chatbot,
|
847 |
transcript_processor_state,
|
848 |
call_id_state,
|
|
|
850 |
origin_state,
|
851 |
ct_state,
|
852 |
],
|
853 |
+
outputs=[chatbot],
|
854 |
)
|
855 |
|
856 |
return demo
|